mirror of
				https://github.com/avatao-content/test-tutorial-framework
				synced 2025-11-04 05:32:55 +00:00 
			
		
		
		
	Refactor messy global scoped_session from webservice
This commit is contained in:
		@@ -1,14 +1,31 @@
 | 
				
			|||||||
from sqlalchemy import Column, Integer, String, create_engine
 | 
					from sqlalchemy import Column, Integer, String, create_engine
 | 
				
			||||||
from sqlalchemy.ext.declarative import declarative_base
 | 
					from sqlalchemy.ext.declarative import declarative_base
 | 
				
			||||||
from sqlalchemy.orm import scoped_session, sessionmaker
 | 
					from sqlalchemy.orm import sessionmaker
 | 
				
			||||||
from passlib.hash import pbkdf2_sha256
 | 
					from passlib.hash import pbkdf2_sha256
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
engine = create_engine('sqlite:///db.db', convert_unicode=True)
 | 
					engine = create_engine('sqlite:///db.db', convert_unicode=True)
 | 
				
			||||||
db_session = scoped_session(sessionmaker(autocommit=False,
 | 
					
 | 
				
			||||||
                                         autoflush=False,
 | 
					
 | 
				
			||||||
                                         bind=engine))
 | 
					class Session:
 | 
				
			||||||
 | 
					    session = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __enter__(self):
 | 
				
			||||||
 | 
					        self.session = Session.create()
 | 
				
			||||||
 | 
					        return self.session
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @staticmethod
 | 
				
			||||||
 | 
					    def create():
 | 
				
			||||||
 | 
					        factory = sessionmaker(autocommit=False,
 | 
				
			||||||
 | 
					                               autoflush=False,
 | 
				
			||||||
 | 
					                               bind=engine)
 | 
				
			||||||
 | 
					        return factory()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __exit__(self, exc_type, exc_val, exc_tb):
 | 
				
			||||||
 | 
					        self.session.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
Base = declarative_base()
 | 
					Base = declarative_base()
 | 
				
			||||||
Base.query = db_session.query_property()
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class User(Base):
 | 
					class User(Base):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -2,7 +2,7 @@ from os import urandom, getenv
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
from flask import Flask, render_template, request, session, url_for
 | 
					from flask import Flask, render_template, request, session, url_for
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from model import db_session, init_db, User, PasswordHasher
 | 
					from model import init_db, User, Session, PasswordHasher
 | 
				
			||||||
 | 
					
 | 
				
			||||||
BASEURL = getenv('BASEURL', '')
 | 
					BASEURL = getenv('BASEURL', '')
 | 
				
			||||||
init_db()
 | 
					init_db()
 | 
				
			||||||
@@ -15,19 +15,15 @@ def get_url(endpoint):
 | 
				
			|||||||
app.jinja_env.globals.update(get_url=get_url)
 | 
					app.jinja_env.globals.update(get_url=get_url)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@app.teardown_appcontext
 | 
					 | 
				
			||||||
def remove_db_session(exception=None):
 | 
					 | 
				
			||||||
    db_session.remove()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
@app.route('/', methods=['GET', 'POST'])
 | 
					@app.route('/', methods=['GET', 'POST'])
 | 
				
			||||||
def index():
 | 
					def index():
 | 
				
			||||||
    if request.method == 'POST':
 | 
					    if request.method == 'POST':
 | 
				
			||||||
        user = User.query.filter(User.username == request.form['username']).first()
 | 
					        with Session() as db:
 | 
				
			||||||
 | 
					            user = db.query(User).filter(User.username == request.form['username']).first()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if not user or not PasswordHasher.verify(request.form['password'], user.passwordhash):
 | 
				
			||||||
 | 
					                return render_template('login.html', alert='Invalid credentials!')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if not user or not PasswordHasher.verify(request.form['password'], user.passwordhash):
 | 
					 | 
				
			||||||
            return render_template('login.html', alert='Invalid credentials!')
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            session['logged_in'] = True
 | 
					            session['logged_in'] = True
 | 
				
			||||||
            session['username'] = request.form['username']
 | 
					            session['username'] = request.form['username']
 | 
				
			||||||
            return render_template('internal.html')
 | 
					            return render_template('internal.html')
 | 
				
			||||||
@@ -40,14 +36,15 @@ def index():
 | 
				
			|||||||
@app.route('/register', methods=['GET', 'POST'])
 | 
					@app.route('/register', methods=['GET', 'POST'])
 | 
				
			||||||
def register():
 | 
					def register():
 | 
				
			||||||
    if request.method == 'POST':
 | 
					    if request.method == 'POST':
 | 
				
			||||||
        validate_register_fields(request)
 | 
					        validate_register_fields(request.form.to_dict())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if User.query.filter(User.username == request.form['username']).all():
 | 
					        with Session() as db:
 | 
				
			||||||
            return render_template('register.html', alert='Username already in use.')
 | 
					            if db.query(User).filter(User.username == request.form['username']).all():
 | 
				
			||||||
 | 
					                return render_template('register.html', alert='Username already in use.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        db_session().add(User(username=request.form['username'],
 | 
					            db.add(User(username=request.form['username'],
 | 
				
			||||||
                              passwordhash=PasswordHasher.hash(request.form['password'])))
 | 
					                        passwordhash=PasswordHasher.hash(request.form['password'])))
 | 
				
			||||||
        db_session().commit()
 | 
					            db.commit()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        return render_template('login.html', success='Account "{}" successfully registered. You can log in now!'.format(request.form['username']))
 | 
					        return render_template('login.html', success='Account "{}" successfully registered. You can log in now!'.format(request.form['username']))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user