Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Sessions to clarify their usage #366

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ jobs:

- name: Install Geckodriver / Selenium
run: |
GECKO_VER=0.30.0
GECKO_VER=0.33.0
CACHED_DOWNLOAD_DIR=~/.local/downloads
FILENAME=geckodriver-v${GECKO_VER}-linux64.tar.gz

Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ repos:
pass_filenames: true
exclude: baselayer|node_modules|static
- repo: https://github.com/pycqa/flake8
rev: 3.8.4
rev: 6.1.0
hooks:
- id: flake8
pass_filenames: true
Expand Down
4 changes: 2 additions & 2 deletions app/access.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from sqlalchemy.orm import joinedload

from baselayer.app.custom_exceptions import AccessError # noqa: F401
from baselayer.app.models import DBSession, Role, Token, User # noqa: F401
from baselayer.app.models import HandlerSession, Role, Token, User # noqa: F401


def auth_or_token(method):
Expand All @@ -26,7 +26,7 @@ def wrapper(self, *args, **kwargs):
token_header = self.request.headers.get("Authorization", None)
if token_header is not None and token_header.startswith("token "):
token_id = token_header.replace("token", "").strip()
with DBSession() as session:
with HandlerSession() as session:
token = session.scalars(
sa.select(Token)
.options(
Expand Down
30 changes: 18 additions & 12 deletions app/handlers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,13 @@
from ..env import load_env
from ..flow import Flow
from ..json_util import to_json
from ..models import DBSession, User, VerifiedSession, bulk_verify, session_context_id
from ..models import (
HandlerSession,
User,
VerifiedSession,
bulk_verify,
session_context_id,
)

env, cfg = load_env()
log = make_log("basehandler")
Expand All @@ -49,7 +55,7 @@ def get_current_user(self):
user_id = int(self.user_id())
oauth_uid = self.get_secure_cookie("user_oauth_uid")
if user_id and oauth_uid:
with DBSession() as session:
with HandlerSession() as session:
try:
user = session.scalars(
sqlalchemy.select(User).where(User.id == user_id)
Expand All @@ -74,7 +80,7 @@ def get_current_user(self):
return None

def login_user(self, user):
with DBSession() as session:
with HandlerSession() as session:
try:
self.set_secure_cookie("user_id", str(user.id))
user = session.scalars(
Expand Down Expand Up @@ -120,7 +126,7 @@ def log_exception(self, typ=None, value=None, tb=None):
)

def on_finish(self):
DBSession.remove()
HandlerSession.remove()


class BaseHandler(PSABaseHandler):
Expand Down Expand Up @@ -153,7 +159,7 @@ def Session(self):
# must merge the user object with the current session
# ref: https://docs.sqlalchemy.org/en/14/orm/session_basics.html#adding-new-or-existing-items
session.add(self.current_user)
session.bind = DBSession.session_factory.kw["bind"]
session.bind = HandlerSession.engine
yield session

def verify_permissions(self):
Expand All @@ -164,20 +170,20 @@ def verify_permissions(self):
"""

# get items to be inserted
new_rows = [row for row in DBSession().new]
new_rows = [row for row in HandlerSession().new]

# get items to be updated
updated_rows = [
row for row in DBSession().dirty if DBSession().is_modified(row)
row for row in HandlerSession().dirty if HandlerSession().is_modified(row)
]

# get items to be deleted
deleted_rows = [row for row in DBSession().deleted]
deleted_rows = [row for row in HandlerSession().deleted]

# get items that were read
read_rows = [
row
for row in set(DBSession().identity_map.values())
for row in set(HandlerSession().identity_map.values())
- (set(updated_rows) | set(new_rows) | set(deleted_rows))
]

Expand All @@ -194,15 +200,15 @@ def verify_permissions(self):
# update transaction state in DB, but don't commit yet. this updates
# or adds rows in the database and uses their new state in joins,
# for permissions checking purposes.
DBSession().flush()
HandlerSession().flush()
bulk_verify("create", new_rows, self.current_user)

def verify_and_commit(self):
"""Verify permissions on the current database session and commit if
successful, otherwise raise an AccessError.
"""
self.verify_permissions()
DBSession().commit()
HandlerSession().commit()

def prepare(self):
self.cfg = self.application.cfg
Expand All @@ -225,7 +231,7 @@ def prepare(self):
N = 5
for i in range(1, N + 1):
try:
assert DBSession.session_factory.kw["bind"] is not None
assert HandlerSession.engine is not None
except Exception as e:
if i == N:
raise e
Expand Down
17 changes: 8 additions & 9 deletions app/model_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ def status(message):
else:
print(f"\r[✓] {message}")
finally:
models.DBSession().commit()
models.HandlerSession().commit()


def drop_tables():
conn = models.DBSession.session_factory.kw["bind"]
print(f"Dropping tables on database {conn.url.database}")
engine = models.HandlerSession.engine
print(f"Dropping tables on database {engine.url.database}")
meta = sa.MetaData()
meta.reflect(bind=conn)
meta.drop_all(bind=conn)
meta.reflect(bind=engine)
meta.drop_all(bind=engine)


def create_tables(retry=5, add=True):
Expand All @@ -45,17 +45,16 @@ def create_tables(retry=5, add=True):
tables.

"""
conn = models.DBSession.session_factory.kw["bind"]
tables = models.Base.metadata.sorted_tables
if tables and not add:
print("Existing tables found; not creating additional tables")
return

for i in range(1, retry + 1):
try:
conn = models.DBSession.session_factory.kw["bind"]
print(f"Creating tables on database {conn.url.database}")
models.Base.metadata.create_all(conn)
engine = models.HandlerSession.engine
print(f"Creating tables on database {engine.url.database}")
models.Base.metadata.create_all(engine)

table_list = ", ".join(list(models.Base.metadata.tables.keys()))
print(f"Refreshed tables: {table_list}")
Expand Down
Loading
Loading