[FastAPI] Refactor db modifications

For SQLAlchemy to automatically understand updates from the
external world, it must use an `autocommit=True` in its session.

This change breaks how we were using commit previously, as
`autocommit=True` causes SQLAlchemy to commit when a
SessionTransaction context hits __exit__.

So, a refactoring was required of our tests: All usage of
any `db.{create,delete}` must be called **within** a
SessionTransaction context, created via new `db.begin()`.

From this point forward, we're going to require:

```
with db.begin():
    db.create(...)
    db.delete(...)
    db.session.delete(object)
```

With this, we now get external DB modifications automatically
without reloading or restarting the FastAPI server, which we
absolutely need for production.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-09-02 16:26:48 -07:00
parent b52059d437
commit a5943bf2ad
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
37 changed files with 998 additions and 902 deletions

View file

@ -59,20 +59,15 @@ def query(model, *args, **kwargs):
return session.query(model).filter(*args, **kwargs)
def create(model, autocommit: bool = True, *args, **kwargs):
def create(model, *args, **kwargs):
instance = model(*args, **kwargs)
add(instance)
if autocommit is True:
commit()
return instance
return add(instance)
def delete(model, *args, autocommit: bool = True, **kwargs):
def delete(model, *args, **kwargs):
instance = session.query(model).filter(*args, **kwargs)
for record in instance:
session.delete(record)
if autocommit is True:
commit()
def rollback():
@ -84,8 +79,25 @@ def add(model):
return model
def commit():
session.commit()
def begin():
""" Begin an SQLAlchemy SessionTransaction.
This context is **required** to perform an modifications to the
database.
Example:
with db.begin():
object = db.create(...)
# On __exit__, db.commit() is run.
with db.begin():
object = db.delete(...)
# On __exit__, db.commit() is run.
:return: A new SessionTransaction based on session
"""
return session.begin()
def get_sqlalchemy_url():
@ -155,23 +167,23 @@ def get_engine(echo: bool = False):
connect_args=connect_args,
echo=echo)
Session = sessionmaker(autocommit=True, autoflush=False, bind=engine)
session = Session()
if db_backend == "sqlite":
# For SQLite, we need to add some custom functions as
# they are used in the reference graph method.
def regexp(regex, item):
return bool(re.search(regex, str(item)))
@event.listens_for(engine, "begin")
def do_begin(conn):
@event.listens_for(engine, "connect")
def do_begin(conn, record):
create_deterministic_function = functools.partial(
conn.connection.create_function,
conn.create_function,
deterministic=True
)
create_deterministic_function("REGEXP", 2, regexp)
Session = sessionmaker(autocommit=False, autoflush=False, bind=engine)
session = Session()
return engine

View file

@ -102,7 +102,7 @@ class User(Base):
def login(self, request: Request, password: str, session_time=0):
""" Login and authenticate a request. """
from aurweb.db import session
from aurweb import db
from aurweb.models.session import Session, generate_unique_sid
if not self._login_approved(request):
@ -112,10 +112,7 @@ class User(Base):
if not self.authenticated:
return None
self.LastLogin = now_ts = datetime.utcnow().timestamp()
self.LastLoginIPAddress = request.client.host
session.commit()
now_ts = datetime.utcnow().timestamp()
session_ts = now_ts + (
session_time if session_time
else aurweb.config.getint("options", "login_timeout")
@ -123,22 +120,23 @@ class User(Base):
sid = None
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
session.add(self.session)
else:
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
with db.begin():
self.LastLogin = now_ts
self.LastLoginIPAddress = request.client.host
if not self.session:
sid = generate_unique_sid()
self.session = Session(UsersID=self.ID, SessionID=sid,
LastUpdateTS=session_ts)
db.add(self.session)
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
last_updated = self.session.LastUpdateTS
if last_updated and last_updated < now_ts:
self.session.SessionID = sid = generate_unique_sid()
else:
# Session is still valid; retrieve the current SID.
sid = self.session.SessionID
self.session.LastUpdateTS = session_ts
session.commit()
self.session.LastUpdateTS = session_ts
request.cookies["AURSID"] = self.session.SessionID
return self.session.SessionID
@ -149,13 +147,11 @@ class User(Base):
return aurweb.auth.has_credential(self, cred, approved)
def logout(self, request):
from aurweb.db import session
del request.cookies["AURSID"]
self.authenticated = False
if self.session:
session.delete(self.session)
session.commit()
with db.begin():
db.session.delete(self.session)
def is_trusted_user(self):
return self.AccountType.ID in {

View file

@ -43,8 +43,6 @@ async def passreset_post(request: Request,
resetkey: str = Form(default=None),
password: str = Form(default=None),
confirm: str = Form(default=None)):
from aurweb.db import session
context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against
@ -86,12 +84,11 @@ async def passreset_post(request: Request,
# We got to this point; everything matched up. Update the password
# and remove the ResetKey.
user.ResetKey = str()
user.update_password(password)
if user.session:
session.delete(user.session)
session.commit()
with db.begin():
user.ResetKey = str()
if user.session:
db.session.delete(user.session)
user.update_password(password)
# Render ?step=complete.
return RedirectResponse(url="/passreset?step=complete",
@ -99,8 +96,8 @@ async def passreset_post(request: Request,
# If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey)
user.ResetKey = resetkey
session.commit()
with db.begin():
user.ResetKey = resetkey
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
ResetKeyNotification(executor, user.ID).send()
@ -364,8 +361,6 @@ async def account_register_post(request: Request,
ON: bool = Form(default=False),
captcha: str = Form(default=None),
captcha_salt: str = Form(...)):
from aurweb.db import session
context = await make_variable_context(request, "Register")
args = dict(await request.form())
@ -394,11 +389,13 @@ async def account_register_post(request: Request,
AccountType.AccountType == "User").first()
# Create a user given all parameters available.
user = db.create(User, Username=U, Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON, ResetKey=resetkey,
AccountType=account_type)
with db.begin():
user = db.create(User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=account_type)
# If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey.
@ -410,10 +407,10 @@ async def account_register_post(request: Request,
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
session.commit()
with db.begin():
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
# Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -499,63 +496,67 @@ async def account_edit_post(request: Request,
status_code=int(HTTPStatus.BAD_REQUEST))
# Set all updated fields as needed.
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
user.BackupEmail = BE or user.BackupEmail
user.RealName = R or user.RealName
user.Homepage = HP or user.Homepage
user.IRCNick = I or user.IRCNick
user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
with db.begin():
user.Username = U or user.Username
user.Email = E or user.Email
user.HideEmail = bool(H)
user.BackupEmail = BE or user.BackupEmail
user.RealName = R or user.RealName
user.Homepage = HP or user.Homepage
user.IRCNick = I or user.IRCNick
user.PGPKey = K or user.PGPKey
user.InactivityTS = datetime.utcnow().timestamp() if J else 0
# If we update the language, update the cookie as well.
if L and L != user.LangPreference:
request.cookies["AURLANG"] = L
user.LangPreference = L
with db.begin():
user.LangPreference = L
context["language"] = L
# If we update the timezone, also update the cookie.
if TZ and TZ != user.Timezone:
user.Timezone = TZ
with db.begin():
user.Timezone = TZ
request.cookies["AURTZ"] = TZ
context["timezone"] = TZ
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
with db.begin():
user.CommentNotify = bool(CN)
user.UpdateNotify = bool(UN)
user.OwnershipNotify = bool(ON)
# If a PK is given, compare it against the target user's PK.
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
# Commit changes, if any.
session.commit()
with db.begin():
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
session.delete(user.ssh_pub_key)
if P and not user.valid_password(P):
# Remove the fields we consumed for passwords.
context["P"] = context["C"] = str()
# If a password was given and it doesn't match the user's, update it.
user.update_password(P)
with db.begin():
user.update_password(P)
if user == request.user:
# If the target user is the request user, login with
# the updated password and update AURSID.
@ -731,21 +732,17 @@ async def terms_of_service_post(request: Request,
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
with db.begin():
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision,
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision)
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))

View file

@ -44,8 +44,6 @@ async def language(request: Request,
setting the language on any page, we want to preserve query
parameters across the redirect.
"""
from aurweb.db import session
if next[0] != '/':
return HTMLResponse(b"Invalid 'next' parameter.", status_code=400)
@ -53,8 +51,8 @@ async def language(request: Request,
# If the user is authenticated, update the user's LangPreference.
if request.user.is_authenticated():
request.user.LangPreference = set_lang
session.commit()
with db.begin():
request.user.LangPreference = set_lang
# In any case, set the response's AURLANG cookie that never expires.
response = RedirectResponse(url=f"{next}{query_string}",

View file

@ -214,10 +214,9 @@ async def trusted_user_proposal_post(request: Request,
return Response("Invalid 'decision' value.",
status_code=int(HTTPStatus.BAD_REQUEST))
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo,
autocommit=False)
voteinfo.ActiveTUs += 1
db.commit()
with db.begin():
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1
context["error"] = "You've already voted for this proposal."
return render_proposal(request, context, proposal, voteinfo, voters, vote)
@ -294,12 +293,13 @@ async def trusted_user_addvote_post(request: Request,
agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)!
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,
Submitted=timestamp, End=timestamp + duration,
Quorum=quorum,
Submitter=request.user)
with db.begin():
voteinfo = db.create(TUVoteInfo,
User=user,
Agenda=agenda,
Submitted=timestamp, End=timestamp + duration,
Quorum=quorum,
Submitter=request.user)
# Redirect to the new proposal.
return RedirectResponse(f"/tu/{voteinfo.ID}",