mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
[FastAPI] Refactor db modifications
For SQLAlchemy to automatically understand updates from the external world, it must use an `autocommit=True` in its session. This change breaks how we were using commit previously, as `autocommit=True` causes SQLAlchemy to commit when a SessionTransaction context hits __exit__. So, a refactoring was required of our tests: All usage of any `db.{create,delete}` must be called **within** a SessionTransaction context, created via new `db.begin()`. From this point forward, we're going to require: ``` with db.begin(): db.create(...) db.delete(...) db.session.delete(object) ``` With this, we now get external DB modifications automatically without reloading or restarting the FastAPI server, which we absolutely need for production. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
b52059d437
commit
a5943bf2ad
37 changed files with 998 additions and 902 deletions
44
aurweb/db.py
44
aurweb/db.py
|
@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
|
|||
return session.query(model).filter(*args, **kwargs)
|
||||
|
||||
|
||||
def create(model, autocommit: bool = True, *args, **kwargs):
|
||||
def create(model, *args, **kwargs):
|
||||
instance = model(*args, **kwargs)
|
||||
add(instance)
|
||||
if autocommit is True:
|
||||
commit()
|
||||
return instance
|
||||
return add(instance)
|
||||
|
||||
|
||||
def delete(model, *args, autocommit: bool = True, **kwargs):
|
||||
def delete(model, *args, **kwargs):
|
||||
instance = session.query(model).filter(*args, **kwargs)
|
||||
for record in instance:
|
||||
session.delete(record)
|
||||
if autocommit is True:
|
||||
commit()
|
||||
|
||||
|
||||
def rollback():
|
||||
|
@ -84,8 +79,25 @@ def add(model):
|
|||
return model
|
||||
|
||||
|
||||
def commit():
|
||||
session.commit()
|
||||
def begin():
|
||||
""" Begin an SQLAlchemy SessionTransaction.
|
||||
|
||||
This context is **required** to perform an modifications to the
|
||||
database.
|
||||
|
||||
Example:
|
||||
|
||||
with db.begin():
|
||||
object = db.create(...)
|
||||
# On __exit__, db.commit() is run.
|
||||
|
||||
with db.begin():
|
||||
object = db.delete(...)
|
||||
# On __exit__, db.commit() is run.
|
||||
|
||||
:return: A new SessionTransaction based on session
|
||||
"""
|
||||
return session.begin()
|
||||
|
||||
|
||||
def get_sqlalchemy_url():
|
||||
|
@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
|
|||
connect_args=connect_args,
|
||||
echo=echo)
|
||||
|
||||
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
|
||||
session = Session()
|
||||
|
||||
if db_backend == "sqlite":
|
||||
# For SQLite, we need to add some custom functions as
|
||||
# they are used in the reference graph method.
|
||||
def regexp(regex, item):
|
||||
return bool(re.search(regex, str(item)))
|
||||
|
||||
@event.listens_for(engine, "begin")
|
||||
def do_begin(conn):
|
||||
@event.listens_for(engine, "connect")
|
||||
def do_begin(conn, record):
|
||||
create_deterministic_function = functools.partial(
|
||||
conn.connection.create_function,
|
||||
conn.create_function,
|
||||
deterministic=True
|
||||
)
|
||||
create_deterministic_function("REGEXP", 2, regexp)
|
||||
|
||||
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
session = Session()
|
||||
|
||||
return engine
|
||||
|
||||
|
||||
|
|
|
@ -102,7 +102,7 @@ class User(Base):
|
|||
def login(self, request: Request, password: str, session_time=0):
|
||||
""" Login and authenticate a request. """
|
||||
|
||||
from aurweb.db import session
|
||||
from aurweb import db
|
||||
from aurweb.models.session import Session, generate_unique_sid
|
||||
|
||||
if not self._login_approved(request):
|
||||
|
@ -112,10 +112,7 @@ class User(Base):
|
|||
if not self.authenticated:
|
||||
return None
|
||||
|
||||
self.LastLogin = now_ts = datetime.utcnow().timestamp()
|
||||
self.LastLoginIPAddress = request.client.host
|
||||
session.commit()
|
||||
|
||||
now_ts = datetime.utcnow().timestamp()
|
||||
session_ts = now_ts + (
|
||||
session_time if session_time
|
||||
else aurweb.config.getint("options", "login_timeout")
|
||||
|
@ -123,22 +120,23 @@ class User(Base):
|
|||
|
||||
sid = None
|
||||
|
||||
if not self.session:
|
||||
sid = generate_unique_sid()
|
||||
self.session = Session(UsersID=self.ID, SessionID=sid,
|
||||
LastUpdateTS=session_ts)
|
||||
session.add(self.session)
|
||||
else:
|
||||
last_updated = self.session.LastUpdateTS
|
||||
if last_updated and last_updated < now_ts:
|
||||
self.session.SessionID = sid = generate_unique_sid()
|
||||
with db.begin():
|
||||
self.LastLogin = now_ts
|
||||
self.LastLoginIPAddress = request.client.host
|
||||
if not self.session:
|
||||
sid = generate_unique_sid()
|
||||
self.session = Session(UsersID=self.ID, SessionID=sid,
|
||||
LastUpdateTS=session_ts)
|
||||
db.add(self.session)
|
||||
else:
|
||||
# Session is still valid; retrieve the current SID.
|
||||
sid = self.session.SessionID
|
||||
last_updated = self.session.LastUpdateTS
|
||||
if last_updated and last_updated < now_ts:
|
||||
self.session.SessionID = sid = generate_unique_sid()
|
||||
else:
|
||||
# Session is still valid; retrieve the current SID.
|
||||
sid = self.session.SessionID
|
||||
|
||||
self.session.LastUpdateTS = session_ts
|
||||
|
||||
session.commit()
|
||||
self.session.LastUpdateTS = session_ts
|
||||
|
||||
request.cookies["AURSID"] = self.session.SessionID
|
||||
return self.session.SessionID
|
||||
|
@ -149,13 +147,11 @@ class User(Base):
|
|||
return aurweb.auth.has_credential(self, cred, approved)
|
||||
|
||||
def logout(self, request):
|
||||
from aurweb.db import session
|
||||
|
||||
del request.cookies["AURSID"]
|
||||
self.authenticated = False
|
||||
if self.session:
|
||||
session.delete(self.session)
|
||||
session.commit()
|
||||
with db.begin():
|
||||
db.session.delete(self.session)
|
||||
|
||||
def is_trusted_user(self):
|
||||
return self.AccountType.ID in {
|
||||
|
|
|
@ -43,8 +43,6 @@ async def passreset_post(request: Request,
|
|||
resetkey: str = Form(default=None),
|
||||
password: str = Form(default=None),
|
||||
confirm: str = Form(default=None)):
|
||||
from aurweb.db import session
|
||||
|
||||
context = await make_variable_context(request, "Password Reset")
|
||||
|
||||
# The user parameter being required, we can match against
|
||||
|
@ -86,12 +84,11 @@ async def passreset_post(request: Request,
|
|||
|
||||
# We got to this point; everything matched up. Update the password
|
||||
# and remove the ResetKey.
|
||||
user.ResetKey = str()
|
||||
user.update_password(password)
|
||||
|
||||
if user.session:
|
||||
session.delete(user.session)
|
||||
session.commit()
|
||||
with db.begin():
|
||||
user.ResetKey = str()
|
||||
if user.session:
|
||||
db.session.delete(user.session)
|
||||
user.update_password(password)
|
||||
|
||||
# Render ?step=complete.
|
||||
return RedirectResponse(url="/passreset?step=complete",
|
||||
|
@ -99,8 +96,8 @@ async def passreset_post(request: Request,
|
|||
|
||||
# If we got here, we continue with issuing a resetkey for the user.
|
||||
resetkey = db.make_random_value(User, User.ResetKey)
|
||||
user.ResetKey = resetkey
|
||||
session.commit()
|
||||
with db.begin():
|
||||
user.ResetKey = resetkey
|
||||
|
||||
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
|
||||
ResetKeyNotification(executor, user.ID).send()
|
||||
|
@ -364,8 +361,6 @@ async def account_register_post(request: Request,
|
|||
ON: bool = Form(default=False),
|
||||
captcha: str = Form(default=None),
|
||||
captcha_salt: str = Form(...)):
|
||||
from aurweb.db import session
|
||||
|
||||
context = await make_variable_context(request, "Register")
|
||||
|
||||
args = dict(await request.form())
|
||||
|
@ -394,11 +389,13 @@ async def account_register_post(request: Request,
|
|||
AccountType.AccountType == "User").first()
|
||||
|
||||
# Create a user given all parameters available.
|
||||
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE,
|
||||
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
|
||||
LangPreference=L, Timezone=TZ, CommentNotify=CN,
|
||||
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey,
|
||||
AccountType=account_type)
|
||||
with db.begin():
|
||||
user = db.create(User, Username=U,
|
||||
Email=E, HideEmail=H, BackupEmail=BE,
|
||||
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
|
||||
LangPreference=L, Timezone=TZ, CommentNotify=CN,
|
||||
UpdateNotify=UN, OwnershipNotify=ON,
|
||||
ResetKey=resetkey, AccountType=account_type)
|
||||
|
||||
# If a PK was given and either one does not exist or the given
|
||||
# PK mismatches the existing user's SSHPubKey.PubKey.
|
||||
|
@ -410,10 +407,10 @@ async def account_register_post(request: Request,
|
|||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
|
||||
PubKey=pubkey,
|
||||
Fingerprint=fingerprint)
|
||||
session.commit()
|
||||
with db.begin():
|
||||
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
|
||||
PubKey=pubkey,
|
||||
Fingerprint=fingerprint)
|
||||
|
||||
# Send a reset key notification to the new user.
|
||||
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
|
||||
|
@ -499,63 +496,67 @@ async def account_edit_post(request: Request,
|
|||
status_code=int(HTTPStatus.BAD_REQUEST))
|
||||
|
||||
# Set all updated fields as needed.
|
||||
user.Username = U or user.Username
|
||||
user.Email = E or user.Email
|
||||
user.HideEmail = bool(H)
|
||||
user.BackupEmail = BE or user.BackupEmail
|
||||
user.RealName = R or user.RealName
|
||||
user.Homepage = HP or user.Homepage
|
||||
user.IRCNick = I or user.IRCNick
|
||||
user.PGPKey = K or user.PGPKey
|
||||
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
|
||||
with db.begin():
|
||||
user.Username = U or user.Username
|
||||
user.Email = E or user.Email
|
||||
user.HideEmail = bool(H)
|
||||
user.BackupEmail = BE or user.BackupEmail
|
||||
user.RealName = R or user.RealName
|
||||
user.Homepage = HP or user.Homepage
|
||||
user.IRCNick = I or user.IRCNick
|
||||
user.PGPKey = K or user.PGPKey
|
||||
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
|
||||
|
||||
# If we update the language, update the cookie as well.
|
||||
if L and L != user.LangPreference:
|
||||
request.cookies["AURLANG"] = L
|
||||
user.LangPreference = L
|
||||
with db.begin():
|
||||
user.LangPreference = L
|
||||
context["language"] = L
|
||||
|
||||
# If we update the timezone, also update the cookie.
|
||||
if TZ and TZ != user.Timezone:
|
||||
user.Timezone = TZ
|
||||
with db.begin():
|
||||
user.Timezone = TZ
|
||||
request.cookies["AURTZ"] = TZ
|
||||
context["timezone"] = TZ
|
||||
|
||||
user.CommentNotify = bool(CN)
|
||||
user.UpdateNotify = bool(UN)
|
||||
user.OwnershipNotify = bool(ON)
|
||||
with db.begin():
|
||||
user.CommentNotify = bool(CN)
|
||||
user.UpdateNotify = bool(UN)
|
||||
user.OwnershipNotify = bool(ON)
|
||||
|
||||
# If a PK is given, compare it against the target user's PK.
|
||||
if PK:
|
||||
# Get the second token in the public key, which is the actual key.
|
||||
pubkey = PK.strip().rstrip()
|
||||
parts = pubkey.split(" ")
|
||||
if len(parts) == 3:
|
||||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
if not user.ssh_pub_key:
|
||||
# No public key exists, create one.
|
||||
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
|
||||
PubKey=pubkey,
|
||||
Fingerprint=fingerprint)
|
||||
elif user.ssh_pub_key.PubKey != pubkey:
|
||||
# A public key already exists, update it.
|
||||
user.ssh_pub_key.PubKey = pubkey
|
||||
user.ssh_pub_key.Fingerprint = fingerprint
|
||||
elif user.ssh_pub_key:
|
||||
# Else, if the user has a public key already, delete it.
|
||||
session.delete(user.ssh_pub_key)
|
||||
|
||||
# Commit changes, if any.
|
||||
session.commit()
|
||||
with db.begin():
|
||||
if PK:
|
||||
# Get the second token in the public key, which is the actual key.
|
||||
pubkey = PK.strip().rstrip()
|
||||
parts = pubkey.split(" ")
|
||||
if len(parts) == 3:
|
||||
# Remove the host part.
|
||||
pubkey = parts[0] + " " + parts[1]
|
||||
fingerprint = get_fingerprint(pubkey)
|
||||
if not user.ssh_pub_key:
|
||||
# No public key exists, create one.
|
||||
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
|
||||
PubKey=pubkey,
|
||||
Fingerprint=fingerprint)
|
||||
elif user.ssh_pub_key.PubKey != pubkey:
|
||||
# A public key already exists, update it.
|
||||
user.ssh_pub_key.PubKey = pubkey
|
||||
user.ssh_pub_key.Fingerprint = fingerprint
|
||||
elif user.ssh_pub_key:
|
||||
# Else, if the user has a public key already, delete it.
|
||||
session.delete(user.ssh_pub_key)
|
||||
|
||||
if P and not user.valid_password(P):
|
||||
# Remove the fields we consumed for passwords.
|
||||
context["P"] = context["C"] = str()
|
||||
|
||||
# If a password was given and it doesn't match the user's, update it.
|
||||
user.update_password(P)
|
||||
with db.begin():
|
||||
user.update_password(P)
|
||||
|
||||
if user == request.user:
|
||||
# If the target user is the request user, login with
|
||||
# the updated password and update AURSID.
|
||||
|
@ -731,21 +732,17 @@ async def terms_of_service_post(request: Request,
|
|||
accept_needed = sorted(unaccepted + diffs)
|
||||
return render_terms_of_service(request, context, accept_needed)
|
||||
|
||||
# For each term we found, query for the matching accepted term
|
||||
# and update its Revision to the term's current Revision.
|
||||
for term in diffs:
|
||||
accepted_term = request.user.accepted_terms.filter(
|
||||
AcceptedTerm.TermsID == term.ID).first()
|
||||
accepted_term.Revision = term.Revision
|
||||
with db.begin():
|
||||
# For each term we found, query for the matching accepted term
|
||||
# and update its Revision to the term's current Revision.
|
||||
for term in diffs:
|
||||
accepted_term = request.user.accepted_terms.filter(
|
||||
AcceptedTerm.TermsID == term.ID).first()
|
||||
accepted_term.Revision = term.Revision
|
||||
|
||||
# For each term that was never accepted, accept it!
|
||||
for term in unaccepted:
|
||||
db.create(AcceptedTerm, User=request.user,
|
||||
Term=term, Revision=term.Revision,
|
||||
autocommit=False)
|
||||
|
||||
if diffs or unaccepted:
|
||||
# If we had any terms to update, commit the changes.
|
||||
db.commit()
|
||||
# For each term that was never accepted, accept it!
|
||||
for term in unaccepted:
|
||||
db.create(AcceptedTerm, User=request.user,
|
||||
Term=term, Revision=term.Revision)
|
||||
|
||||
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
|
||||
|
|
|
@ -44,8 +44,6 @@ async def language(request: Request,
|
|||
setting the language on any page, we want to preserve query
|
||||
parameters across the redirect.
|
||||
"""
|
||||
from aurweb.db import session
|
||||
|
||||
if next[0] != '/':
|
||||
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
|
||||
|
||||
|
@ -53,8 +51,8 @@ async def language(request: Request,
|
|||
|
||||
# If the user is authenticated, update the user's LangPreference.
|
||||
if request.user.is_authenticated():
|
||||
request.user.LangPreference = set_lang
|
||||
session.commit()
|
||||
with db.begin():
|
||||
request.user.LangPreference = set_lang
|
||||
|
||||
# In any case, set the response's AURLANG cookie that never expires.
|
||||
response = RedirectResponse(url=f"{next}{query_string}",
|
||||
|
|
|
@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
|
|||
return Response("Invalid 'decision' value.",
|
||||
status_code=int(HTTPStatus.BAD_REQUEST))
|
||||
|
||||
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo,
|
||||
autocommit=False)
|
||||
voteinfo.ActiveTUs += 1
|
||||
db.commit()
|
||||
with db.begin():
|
||||
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
|
||||
voteinfo.ActiveTUs += 1
|
||||
|
||||
context["error"] = "You've already voted for this proposal."
|
||||
return render_proposal(request, context, proposal, voteinfo, voters, vote)
|
||||
|
@ -294,12 +293,13 @@ async def trusted_user_addvote_post(request: Request,
|
|||
agenda = re.sub(r'<[/]?style.*>', '', agenda)
|
||||
|
||||
# Create a new TUVoteInfo (proposal)!
|
||||
voteinfo = db.create(TUVoteInfo,
|
||||
User=user,
|
||||
Agenda=agenda,
|
||||
Submitted=timestamp, End=timestamp + duration,
|
||||
Quorum=quorum,
|
||||
Submitter=request.user)
|
||||
with db.begin():
|
||||
voteinfo = db.create(TUVoteInfo,
|
||||
User=user,
|
||||
Agenda=agenda,
|
||||
Submitted=timestamp, End=timestamp + duration,
|
||||
Quorum=quorum,
|
||||
Submitter=request.user)
|
||||
|
||||
# Redirect to the new proposal.
|
||||
return RedirectResponse(f"/tu/{voteinfo.ID}",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue