diff --git a/solvable/src/webservice/model.py b/solvable/src/webservice/model.py index 6f7f5ff..f480b03 100644 --- a/solvable/src/webservice/model.py +++ b/solvable/src/webservice/model.py @@ -1,14 +1,31 @@ from sqlalchemy import Column, Integer, String, create_engine from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import scoped_session, sessionmaker +from sqlalchemy.orm import sessionmaker from passlib.hash import pbkdf2_sha256 + engine = create_engine('sqlite:///db.db', convert_unicode=True) -db_session = scoped_session(sessionmaker(autocommit=False, - autoflush=False, - bind=engine)) + + +class Session: + session = None + + def __enter__(self): + self.session = Session.create() + return self.session + + @staticmethod + def create(): + factory = sessionmaker(autocommit=False, + autoflush=False, + bind=engine) + return factory() + + def __exit__(self, exc_type, exc_val, exc_tb): + self.session.close() + + Base = declarative_base() -Base.query = db_session.query_property() class User(Base): diff --git a/solvable/src/webservice/server.py b/solvable/src/webservice/server.py index cc8795c..d7d1361 100644 --- a/solvable/src/webservice/server.py +++ b/solvable/src/webservice/server.py @@ -2,7 +2,7 @@ from os import urandom, getenv from flask import Flask, render_template, request, session, url_for -from model import db_session, init_db, User, PasswordHasher +from model import init_db, User, Session, PasswordHasher BASEURL = getenv('BASEURL', '') init_db() @@ -15,19 +15,15 @@ def get_url(endpoint): app.jinja_env.globals.update(get_url=get_url) -@app.teardown_appcontext -def remove_db_session(exception=None): - db_session.remove() - - @app.route('/', methods=['GET', 'POST']) def index(): if request.method == 'POST': - user = User.query.filter(User.username == request.form['username']).first() + with Session() as db: + user = db.query(User).filter(User.username == request.form['username']).first() + + if not user or not PasswordHasher.verify(request.form['password'], user.passwordhash): + return render_template('login.html', alert='Invalid credentials!') - if not user or not PasswordHasher.verify(request.form['password'], user.passwordhash): - return render_template('login.html', alert='Invalid credentials!') - else: session['logged_in'] = True session['username'] = request.form['username'] return render_template('internal.html') @@ -40,14 +36,15 @@ def index(): @app.route('/register', methods=['GET', 'POST']) def register(): if request.method == 'POST': - validate_register_fields(request) + validate_register_fields(request.form.to_dict()) - if User.query.filter(User.username == request.form['username']).all(): - return render_template('register.html', alert='Username already in use.') + with Session() as db: + if db.query(User).filter(User.username == request.form['username']).all(): + return render_template('register.html', alert='Username already in use.') - db_session().add(User(username=request.form['username'], - passwordhash=PasswordHasher.hash(request.form['password']))) - db_session().commit() + db.add(User(username=request.form['username'], + passwordhash=PasswordHasher.hash(request.form['password']))) + db.commit() return render_template('login.html', success='Account "{}" successfully registered. You can log in now!'.format(request.form['username']))