diff --git a/solvable/src/webservice/model.py b/solvable/src/webservice/model.py index d722cdc..b084b1e 100644 --- a/solvable/src/webservice/model.py +++ b/solvable/src/webservice/model.py @@ -13,18 +13,29 @@ session_factory = sessionmaker( ) -@contextmanager -def Session(factory=session_factory): - session = factory() - try: - yield session - session.commit() - except: - session.rollback() - raise - # session is closed by flask - # finally: - # session.close() +class SessionWrapper: + def __init__(self): + self._session_factory = session_factory + self._session_handle = None + + @contextmanager + def session(self): + try: + yield self._session + self._session.commit() + except: + self._session.rollback() + raise + + @property + def _session(self): + if self._session_handle is None: + self._session_handle = self._session_factory() + return self._session_handle + + def teardown(self): + if self._session_handle is not None: + self._session_handle.close() Base = declarative_base() diff --git a/solvable/src/webservice/server.py b/solvable/src/webservice/server.py index c359c61..1b36762 100644 --- a/solvable/src/webservice/server.py +++ b/solvable/src/webservice/server.py @@ -1,9 +1,8 @@ from os import urandom, getenv -from functools import partial from flask import Flask, render_template, request, session, url_for, g -from model import init_db, session_factory, Session +from model import init_db, SessionWrapper from user_ops import UserOps from errors import InvalidCredentialsError, UserExistsError @@ -16,25 +15,24 @@ app.jinja_env.globals.update( # pylint: disable=no-member ) -def get_db_session(): - if not hasattr(g, 'db_session'): - g.db_session = session_factory() - return g.db_session - -Session = partial(Session, get_db_session) +@app.before_request +def setup_db(): + # pylint: disable=protected-access + g._db_session_wrapper = SessionWrapper() + g.db_session = g._db_session_wrapper.session @app.teardown_appcontext -def close_db_session(err): # pylint: disable=unused-argument - if hasattr(g, 'db_session'): - g.db_session.close() +def close_db_session(_): + # pylint: disable=protected-access + g._db_session_wrapper.teardown() @app.route('/', methods=['GET', 'POST']) def index(): if request.method == 'POST': try: - with Session() as db_session: + with g.db_session() as db_session: UserOps( request.form.get('username'), request.form.get('password'), @@ -67,7 +65,7 @@ def register(): return render_template('register.html', alert='Passwords do not match! Please try again.') try: - with Session() as db_session: + with g.db_session() as db_session: UserOps( request.form.get('username'), request.form.get('password'),