diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index b7a8d9a9..9caeca6c 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -81,7 +81,7 @@ jobs: - name: Install Geckodriver / Selenium run: | - GECKO_VER=0.30.0 + GECKO_VER=0.33.0 CACHED_DOWNLOAD_DIR=~/.local/downloads FILENAME=geckodriver-v${GECKO_VER}-linux64.tar.gz diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3138b227..75f19303 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -23,7 +23,7 @@ repos: pass_filenames: true exclude: baselayer|node_modules|static - repo: https://github.com/pycqa/flake8 - rev: 3.8.4 + rev: 6.1.0 hooks: - id: flake8 pass_filenames: true diff --git a/app/access.py b/app/access.py index 8d13b5b5..dd25ddb6 100644 --- a/app/access.py +++ b/app/access.py @@ -5,7 +5,7 @@ from sqlalchemy.orm import joinedload from baselayer.app.custom_exceptions import AccessError # noqa: F401 -from baselayer.app.models import DBSession, Role, Token, User # noqa: F401 +from baselayer.app.models import HandlerSession, Role, Token, User # noqa: F401 def auth_or_token(method): @@ -26,7 +26,7 @@ def wrapper(self, *args, **kwargs): token_header = self.request.headers.get("Authorization", None) if token_header is not None and token_header.startswith("token "): token_id = token_header.replace("token", "").strip() - with DBSession() as session: + with HandlerSession() as session: token = session.scalars( sa.select(Token) .options( diff --git a/app/handlers/base.py b/app/handlers/base.py index 6da53dcc..b761b25a 100644 --- a/app/handlers/base.py +++ b/app/handlers/base.py @@ -22,7 +22,13 @@ from ..env import load_env from ..flow import Flow from ..json_util import to_json -from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id +from ..models import ( + HandlerSession, + User, + VerifiedSession, + bulk_verify, + session_context_id, +) env, cfg = load_env() log = make_log("basehandler") @@ -49,7 +55,7 @@ def get_current_user(self): user_id = int(self.user_id()) oauth_uid = self.get_secure_cookie("user_oauth_uid") if user_id and oauth_uid: - with DBSession() as session: + with HandlerSession() as session: try: user = session.scalars( sqlalchemy.select(User).where(User.id == user_id) @@ -74,7 +80,7 @@ def get_current_user(self): return None def login_user(self, user): - with DBSession() as session: + with HandlerSession() as session: try: self.set_secure_cookie("user_id", str(user.id)) user = session.scalars( @@ -120,7 +126,7 @@ def log_exception(self, typ=None, value=None, tb=None): ) def on_finish(self): - DBSession.remove() + HandlerSession.remove() class BaseHandler(PSABaseHandler): @@ -153,7 +159,7 @@ def Session(self): # must merge the user object with the current session # ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#adding-new-or-existing-items session.add(self.current_user) - session.bind = DBSession.session_factory.kw["bind"] + session.bind = HandlerSession.engine yield session def verify_permissions(self): @@ -164,20 +170,20 @@ def verify_permissions(self): """ # get items to be inserted - new_rows = [row for row in DBSession().new] + new_rows = [row for row in HandlerSession().new] # get items to be updated updated_rows = [ - row for row in DBSession().dirty if DBSession().is_modified(row) + row for row in HandlerSession().dirty if HandlerSession().is_modified(row) ] # get items to be deleted - deleted_rows = [row for row in DBSession().deleted] + deleted_rows = [row for row in HandlerSession().deleted] # get items that were read read_rows = [ row - for row in set(DBSession().identity_map.values()) + for row in set(HandlerSession().identity_map.values()) - (set(updated_rows) | set(new_rows) | set(deleted_rows)) ] @@ -194,7 +200,7 @@ def verify_permissions(self): # update transaction state in DB, but don't commit yet. this updates # or adds rows in the database and uses their new state in joins, # for permissions checking purposes. - DBSession().flush() + HandlerSession().flush() bulk_verify("create", new_rows, self.current_user) def verify_and_commit(self): @@ -202,7 +208,7 @@ def verify_and_commit(self): successful, otherwise raise an AccessError. """ self.verify_permissions() - DBSession().commit() + HandlerSession().commit() def prepare(self): self.cfg = self.application.cfg @@ -225,7 +231,7 @@ def prepare(self): N = 5 for i in range(1, N + 1): try: - assert DBSession.session_factory.kw["bind"] is not None + assert HandlerSession.engine is not None except Exception as e: if i == N: raise e diff --git a/app/model_util.py b/app/model_util.py index 6a52f583..037371bb 100644 --- a/app/model_util.py +++ b/app/model_util.py @@ -21,15 +21,15 @@ def status(message): else: print(f"\r[✓] {message}") finally: - models.DBSession().commit() + models.HandlerSession().commit() def drop_tables(): - conn = models.DBSession.session_factory.kw["bind"] - print(f"Dropping tables on database {conn.url.database}") + engine = models.HandlerSession.engine + print(f"Dropping tables on database {engine.url.database}") meta = sa.MetaData() - meta.reflect(bind=conn) - meta.drop_all(bind=conn) + meta.reflect(bind=engine) + meta.drop_all(bind=engine) def create_tables(retry=5, add=True): @@ -45,7 +45,6 @@ def create_tables(retry=5, add=True): tables. """ - conn = models.DBSession.session_factory.kw["bind"] tables = models.Base.metadata.sorted_tables if tables and not add: print("Existing tables found; not creating additional tables") @@ -53,9 +52,9 @@ def create_tables(retry=5, add=True): for i in range(1, retry + 1): try: - conn = models.DBSession.session_factory.kw["bind"] - print(f"Creating tables on database {conn.url.database}") - models.Base.metadata.create_all(conn) + engine = models.HandlerSession.engine + print(f"Creating tables on database {engine.url.database}") + models.Base.metadata.create_all(engine) table_list = ", ".join(list(models.Base.metadata.tables.keys())) print(f"Refreshed tables: {table_list}") diff --git a/app/models.py b/app/models.py index ce6bc5ff..bd693986 100644 --- a/app/models.py +++ b/app/models.py @@ -28,9 +28,24 @@ log_database = cfg.get("log.database", False) log_database_pool = cfg.get("log.database_pool", False) + +# This provides one session per *thread* +ThreadSession = scoped_session(sessionmaker()) + + +# This provides one session per *handler* +# It is recommended to use the handler session via +# self.Session, which has some knowledge +# of the current user. See `handlers/base.py` +# +# DBSession has been renamed to HandlerSession +# to make it clearer what it is doing. + +# We've renamed DBSession +# It is not recommended to use DBSession directly; session_context_id = contextvars.ContextVar("request_id", default=None) -# left here for backward compatibility: DBSession = scoped_session(sessionmaker(), scopefunc=session_context_id.get) +HandlerSession = DBSession class _VerifiedSession(sa.orm.session.Session): @@ -61,7 +76,7 @@ def __init__(self, user_or_token, **kwargs): or be generating an unverified session to only query the user with a certain id. Example: - with DBSession() as session: + with HandlerSession() as session: user = session.scalars( sa.select(User).where(User.id == user_id) ).first() @@ -159,7 +174,7 @@ def bulk_verify(mode, collection, accessor): ).subquery() inaccessible_row_ids = ( - DBSession() + HandlerSession() .scalars( sa.select(record_cls.id) .outerjoin( @@ -258,7 +273,7 @@ def init_db( "max_overflow": 10, "pool_recycle": 3600, } - conn = sa.create_engine( + engine = sa.create_engine( url, client_encoding="utf8", executemany_mode="values_plus_batch", @@ -268,10 +283,17 @@ def init_db( **{**default_engine_args, **engine_args}, ) - DBSession.configure(bind=conn, autoflush=autoflush, future=True) - Base.metadata.bind = conn + HandlerSession.configure(bind=engine, autoflush=autoflush, future=True) + # Convenience attribute to easily access the engine, otherwise would need + # HandlerSession.session_factory.kw["bind"] + HandlerSession.engine = engine + + ThreadSession.configure(bind=engine, autoflush=autoflush, future=True) + ThreadSession.engine = engine + + Base.metadata.bind = engine - return conn + return engine class SlugifiedStr(sa.types.TypeDecorator): @@ -478,8 +500,8 @@ def query_accessible_rows(self, cls, user_or_token, columns=None): """ # return only selected columns if requested if columns is not None: - return DBSession().query(*columns).select_from(cls) - return DBSession().query(cls) + return HandlerSession().query(*columns).select_from(cls) + return HandlerSession().query(cls) def select_accessible_rows(self, cls, user_or_token, columns=None): """Construct a Select object that, when executed, returns the rows of a @@ -571,9 +593,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None): # return only selected columns if requested if columns is not None: - query = DBSession().query(*columns).select_from(cls) + query = HandlerSession().query(*columns).select_from(cls) else: - query = DBSession().query(cls) + query = HandlerSession().query(cls) # traverse the relationship chain via sequential JOINs for relationship_name in self.relationship_names: @@ -735,9 +757,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None): # return only selected columns if requested if columns is None: - base = DBSession().query(cls) + base = HandlerSession().query(cls) else: - base = DBSession().query(*columns).select_from(cls) + base = HandlerSession().query(*columns).select_from(cls) # ensure the target class has all the relationships referred to # in this instance @@ -922,9 +944,9 @@ def query_accessible_rows(self, cls, user_or_token, columns=None): # retrieve specified columns if requested if columns is not None: - query = DBSession().query(*columns).select_from(cls) + query = HandlerSession().query(*columns).select_from(cls) else: - query = DBSession().query(cls) + query = HandlerSession().query(cls) # keep track of columns that will be null in the case of an unsuccessful # match for OR logic. @@ -1076,9 +1098,12 @@ def query_accessible_rows(self, cls, user_or_token, columns=None): # otherwise, all records are inaccessible if columns is not None: return ( - DBSession().query(*columns).select_from(cls).filter(sa.literal(False)) + HandlerSession() + .query(*columns) + .select_from(cls) + .filter(sa.literal(False)) ) - return DBSession().query(cls).filter(sa.literal(False)) + return HandlerSession().query(cls).filter(sa.literal(False)) def select_accessible_rows(self, cls, user_or_token, columns=None): """Construct a Select object that, when executed, returns the rows of a @@ -1148,7 +1173,7 @@ def __init__(self, query_or_query_generator): Query (SQLA 1.4): >>>> CustomUserAccessControl( - DBSession().query(Department).join(Employee).group_by( + HandlerSession().query(Department).join(Employee).group_by( Department.id ).having(sa.func.bool_and(Employee.is_manager.is_(True))) ) @@ -1166,8 +1191,8 @@ def __init__(self, query_or_query_generator): Query (SQLA 1.4): >>>> def access_logic(cls, user_or_token): ... if user_or_token.is_system_admin: - ... return DBSession().query(cls) - ... return DBSession().query(cls).join(Employee).group_by( + ... return HandlerSession().query(cls) + ... return HandlerSession().query(cls).join(Employee).group_by( ... cls.id ... ).having(sa.func.bool_and(Employee.is_manager.is_(True))) >>>> CustomUserAccessControl(access_logic) @@ -1303,7 +1328,7 @@ def is_accessible_by(self, user_or_token, mode="read"): # Query for the value of the access_func for this particular record and # return the result. - result = DBSession().execute(stmt).scalar_one() > 0 + result = HandlerSession().execute(stmt).scalar_one() > 0 if result is None: result = False @@ -1353,7 +1378,7 @@ def get_if_accessible_by( # TODO: vectorize this for pk in standardized: - instance = DBSession().query(cls).options(options).get(pk.item()) + instance = HandlerSession().query(cls).options(options).get(pk.item()) if instance is None or not instance.is_accessible_by( user_or_token, mode=mode ): @@ -1467,7 +1492,7 @@ def get( standardized = np.atleast_1d(id_or_list) result = [] - with DBSession() as session: + with HandlerSession() as session: # TODO: vectorize this for pk in standardized: if options: @@ -1519,7 +1544,7 @@ def get_all( If columns is specified, will return a list of tuples containing the data from each column requested. """ - with DBSession() as session: + with HandlerSession() as session: stmt = cls.select(user_or_token, mode, options, columns) values = session.scalars(stmt).all() @@ -1566,7 +1591,7 @@ def select( stmt = stmt.options(option) return stmt - query = DBSession.query_property() + query = HandlerSession.query_property() id = sa.Column( sa.Integer, @@ -1608,8 +1633,8 @@ def __repr__(self): def to_dict(self): """Serialize this object to a Python dictionary.""" if sa.inspection.inspect(self).expired: - self = DBSession().merge(self) - DBSession().refresh(self) + self = HandlerSession().merge(self) + HandlerSession().refresh(self) return {k: v for k, v in self.__dict__.items() if not k.startswith("_")} @classmethod @@ -1632,7 +1657,7 @@ def get_if_readable_by(cls, ident, user_or_token, options=[]): obj : baselayer.app.models.Base The requested entity. """ - obj = DBSession().query(cls).options(options).get(ident) + obj = HandlerSession().query(cls).options(options).get(ident) if obj is not None and not obj.is_readable_by(user_or_token): raise AccessError("Insufficient permissions.") @@ -1659,7 +1684,7 @@ def is_readable_by(self, user_or_token): def create_or_get(cls, id): """Return a new `cls` if an instance with the specified primary key does not exist, else return the existing instance.""" - obj = DBSession().query(cls).get(id) + obj = HandlerSession().query(cls).get(id) if obj is not None: return obj else: @@ -1876,7 +1901,7 @@ class User(Base): role_ids = association_proxy( "roles", "id", - creator=lambda r: DBSession().query(Role).get(r), + creator=lambda r: HandlerSession().query(Role).get(r), ) tokens = relationship( "Token", @@ -1982,7 +2007,7 @@ class Token(Base): lazy="selectin", ) acl_ids = association_proxy( - "acls", "id", creator=lambda acl: DBSession().query(ACL).get(acl) + "acls", "id", creator=lambda acl: HandlerSession().query(ACL).get(acl) ) permissions = acl_ids diff --git a/app/psa.py b/app/psa.py index f426ee1a..4681fa43 100644 --- a/app/psa.py +++ b/app/psa.py @@ -26,7 +26,7 @@ from sqlalchemy.types import PickleType, Text from tornado.template import Loader, Template -from baselayer.app.models import Base, DBSession, User +from baselayer.app.models import Base, HandlerSession, User from .env import load_env @@ -393,7 +393,7 @@ def _new_instance(cls, model, *args, **kwargs): @classmethod def _save_instance(cls, instance): instance_id = instance.id if hasattr(instance, "id") else None - session = DBSession() + session = HandlerSession() session.add(instance) if cls.COMMIT_SESSION: session.commit() @@ -404,7 +404,7 @@ def _save_instance(cls, instance): except AssertionError: session.commit() return ( - DBSession.query(instance.__class__).filter_by(id=instance_id).first() + HandlerSession.query(instance.__class__).filter_by(id=instance_id).first() if instance_id else instance ) @@ -439,9 +439,9 @@ def set_extra_data(self, extra_data=None): @classmethod def allowed_to_disconnect(cls, user, backend_name, association_id=None): if association_id is not None: - qs = DBSession().query(cls).filter(cls.id != association_id) + qs = HandlerSession().query(cls).filter(cls.id != association_id) else: - qs = DBSession().query(cls).filter(cls.provider != backend_name) + qs = HandlerSession().query(cls).filter(cls.provider != backend_name) qs = qs.filter(cls.user == user) if hasattr(user, "has_usable_password"): # TODO @@ -452,7 +452,7 @@ def allowed_to_disconnect(cls, user, backend_name, association_id=None): @classmethod def disconnect(cls, entry): - session = DBSession() + session = HandlerSession() session.delete(entry) try: session.flush() @@ -466,7 +466,8 @@ def user_exists(cls, *args, **kwargs): Arguments are directly passed to filter() manager method. """ return ( - DBSession().query(cls.user_model()).filter_by(*args, **kwargs).count() > 0 + HandlerSession().query(cls.user_model()).filter_by(*args, **kwargs).count() + > 0 ) @classmethod @@ -479,24 +480,29 @@ def create_user(cls, *args, **kwargs): @classmethod def get_user(cls, pk): - return DBSession().query(cls.user_model()).filter_by(id=pk).first() + return HandlerSession().query(cls.user_model()).filter_by(id=pk).first() @classmethod def get_users_by_email(cls, email): - return DBSession().query(cls.user_model()).filter_by(email=email).all() + return HandlerSession().query(cls.user_model()).filter_by(email=email).all() @classmethod def get_social_auth(cls, provider, uid): if not isinstance(uid, str): uid = str(uid) try: - return DBSession().query(cls).filter_by(provider=provider, uid=uid).first() + return ( + HandlerSession() + .query(cls) + .filter_by(provider=provider, uid=uid) + .first() + ) except IndexError: return None @classmethod def get_social_auth_for_user(cls, user, provider=None, id=None): - qs = DBSession().query(cls).filter_by(user_id=user.id) + qs = HandlerSession().query(cls).filter_by(user_id=user.id) if provider: qs = qs.filter_by(provider=provider) if id: @@ -522,7 +528,7 @@ class SQLAlchemyNonceMixin(SQLAlchemyMixin, NonceMixin): def use(cls, server_url, timestamp, salt): kwargs = {"server_url": server_url, "timestamp": timestamp, "salt": salt} - qs = DBSession().query(cls).filter_by(**kwargs).first() + qs = HandlerSession().query(cls).filter_by(**kwargs).first() if qs is None: qs = cls._new_instance(cls, **kwargs) return qs @@ -543,7 +549,7 @@ class SQLAlchemyAssociationMixin(SQLAlchemyMixin, AssociationMixin): def store(cls, server_url, association): # Don't use get_or_create because issued cannot be null assoc = ( - DBSession() + HandlerSession() .query(cls) .filter_by(server_url=server_url, handle=association.handle) .first() @@ -559,11 +565,11 @@ def store(cls, server_url, association): @classmethod def get(cls, *args, **kwargs): - return DBSession().query(cls).filter_by(*args, **kwargs).first() + return HandlerSession().query(cls).filter_by(*args, **kwargs).first() @classmethod def remove(cls, ids_to_delete): - with DBSession() as session: + with HandlerSession() as session: assocs = session.query(cls).filter(cls.id.in_(ids_to_delete)).all() for assoc in assocs: session.delete(assoc) @@ -579,7 +585,7 @@ class SQLAlchemyCodeMixin(SQLAlchemyMixin, CodeMixin): @classmethod def get_code(cls, code): - return DBSession().query(cls).filter_by(code=code).first() + return HandlerSession().query(cls).filter_by(code=code).first() class SQLAlchemyPartialMixin(SQLAlchemyMixin, PartialMixin): @@ -592,11 +598,11 @@ class SQLAlchemyPartialMixin(SQLAlchemyMixin, PartialMixin): @classmethod def load(cls, token): - return DBSession().query(cls).filter_by(token=token).first() + return HandlerSession().query(cls).filter_by(token=token).first() @classmethod def destroy(cls, token): - with DBSession() as session: + with HandlerSession() as session: partial = session.query(cls).filter_by(token=token).first() if partial: session.delete(partial) diff --git a/app/test_util.py b/app/test_util.py index 342c6f2b..4eff40b5 100644 --- a/app/test_util.py +++ b/app/test_util.py @@ -146,7 +146,7 @@ def click_css(self, css, timeout=10, scroll_parent=False): def driver(request): from selenium import webdriver - options = webdriver.firefox.options.Options() + options = webdriver.FirefoxOptions() if "BASELAYER_TEST_HEADLESS" in os.environ: options.headless = True options.set_preference("devtools.console.stdout.content", True) @@ -199,6 +199,6 @@ def login(driver): @pytest.fixture(scope="function", autouse=True) def reset_state(request): def teardown(): - models.DBSession().rollback() + models.HandlerSession().rollback() request.addfinalizer(teardown) diff --git a/doc/dev.md b/doc/dev.md index 2295d997..813d9fbb 100644 --- a/doc/dev.md +++ b/doc/dev.md @@ -31,7 +31,7 @@ make use of include: ``` from baselayer.app.env import load_env -from baselayer.models import DBSession, init_db +from baselayer.models import HandlerSession, init_db env, cfg = load_env() init_db(**cfg['database']) ``` @@ -39,13 +39,13 @@ init_db(**cfg['database']) - The session object controls various DB state operations: ``` -DBSession().add(obj) # add a new object into the DB -DBSession().commit() # commit modifications to objects -DBSession().rollback() # recover after a DB error +HandlerSession().add(obj) # add a new object into the DB +HandlerSession().commit() # commit modifications to objects +HandlerSession().rollback() # recover after a DB error ``` - Generic logic applicable to any model is included in the base model class `baselayer.app.models.Base` (`to_dict`, `__str__`, etc.), but can be overridden within a specific model -- Models can be selected directly (`User.query.all()`), or more specific queries can be constructed via the session object (`DBSession().query(User.id).all()`) +- Models can be selected directly (`User.query.all()`), or more specific queries can be constructed via the session object (`HandlerSession().query(User.id).all()`) - Convenience functionality: - Join relationships: some multi-step relationships are defined through joins using the `secondary` parameter to eliminate queries from the intermediate table; e.g., `User.acls` instad of `[r.acls for r in User.roles]` - [Association proxies](http://docs.sqlalchemy.org/en/latest/orm/extensions/associationproxy.html): shortcut to some attribute of a related object; e.g., `User.permissions` instead of `[a.id for a in User.acls]` @@ -59,7 +59,7 @@ DBSession().rollback() # recover after a DB error To start a session without verification (i.e., when not committing to DB): ``` -with DBSession() as session: +with HandlerSession() as session: ... ``` @@ -121,7 +121,11 @@ with VerifiedSession(user_or_token) as session: ``` If not using `commit()`, the call to `VerifiedSession(user_or_token)` -can be replaced with `DBSession()` with no arguments. +can be replaced with `HandlerSession()` with no arguments. + +When operating outside of a handler, such as when firing off new +tasks, or inside of services, `ThreadSession` must be used instead of +`HandlerSession`. ## Standards diff --git a/services/cron/cron.py b/services/cron/cron.py index 4ec197f7..9cfab5e5 100644 --- a/services/cron/cron.py +++ b/services/cron/cron.py @@ -8,7 +8,7 @@ from dateutil.parser import parse as parse_time from baselayer.app.env import load_env -from baselayer.app.models import CronJobRun, DBSession, init_db +from baselayer.app.models import CronJobRun, ThreadSession, init_db from baselayer.log import make_log log = make_log("cron") @@ -90,9 +90,11 @@ def cache_to_file(self): output, _ = proc.communicate() except Exception as e: log(f"Error executing {script}: {e}") - DBSession().add(CronJobRun(script=script, exit_status=1, output=str(e))) + ThreadSession().add( + CronJobRun(script=script, exit_status=1, output=str(e)) + ) else: - DBSession().add( + ThreadSession().add( CronJobRun( script=script, exit_status=proc.returncode, @@ -100,6 +102,6 @@ def cache_to_file(self): ) ) finally: - DBSession().commit() + ThreadSession().commit() time.sleep(60)