1
0
mirror of https://github.com/avatao-content/test-tutorial-framework synced 2024-11-15 04:07:17 +00:00

Bootify webservice db session handling

This commit is contained in:
Kristóf Tóth 2018-08-31 16:57:00 +02:00
parent e7d78ed289
commit 4a1073e524
2 changed files with 34 additions and 25 deletions

View File

@ -13,18 +13,29 @@ session_factory = sessionmaker(
) )
@contextmanager class SessionWrapper:
def Session(factory=session_factory): def __init__(self):
session = factory() self._session_factory = session_factory
self._session_handle = None
@contextmanager
def session(self):
try: try:
yield session yield self._session
session.commit() self._session.commit()
except: except:
session.rollback() self._session.rollback()
raise raise
# session is closed by flask
# finally: @property
# session.close() def _session(self):
if self._session_handle is None:
self._session_handle = self._session_factory()
return self._session_handle
def teardown(self):
if self._session_handle is not None:
self._session_handle.close()
Base = declarative_base() Base = declarative_base()

View File

@ -1,9 +1,8 @@
from os import urandom, getenv from os import urandom, getenv
from functools import partial
from flask import Flask, render_template, request, session, url_for, g from flask import Flask, render_template, request, session, url_for, g
from model import init_db, session_factory, Session from model import init_db, SessionWrapper
from user_ops import UserOps from user_ops import UserOps
from errors import InvalidCredentialsError, UserExistsError from errors import InvalidCredentialsError, UserExistsError
@ -16,25 +15,24 @@ app.jinja_env.globals.update( # pylint: disable=no-member
) )
def get_db_session(): @app.before_request
if not hasattr(g, 'db_session'): def setup_db():
g.db_session = session_factory() # pylint: disable=protected-access
return g.db_session g._db_session_wrapper = SessionWrapper()
g.db_session = g._db_session_wrapper.session
Session = partial(Session, get_db_session)
@app.teardown_appcontext @app.teardown_appcontext
def close_db_session(err): # pylint: disable=unused-argument def close_db_session(_):
if hasattr(g, 'db_session'): # pylint: disable=protected-access
g.db_session.close() g._db_session_wrapper.teardown()
@app.route('/', methods=['GET', 'POST']) @app.route('/', methods=['GET', 'POST'])
def index(): def index():
if request.method == 'POST': if request.method == 'POST':
try: try:
with Session() as db_session: with g.db_session() as db_session:
UserOps( UserOps(
request.form.get('username'), request.form.get('username'),
request.form.get('password'), request.form.get('password'),
@ -67,7 +65,7 @@ def register():
return render_template('register.html', alert='Passwords do not match! Please try again.') return render_template('register.html', alert='Passwords do not match! Please try again.')
try: try:
with Session() as db_session: with g.db_session() as db_session:
UserOps( UserOps(
request.form.get('username'), request.form.get('username'),
request.form.get('password'), request.form.get('password'),