[FastAPI] add /tos routes (get, post)

This clones the end goal behavior of PHP, but it does not
concern itself with the revision form array at all.

Since this page on PHP renders out the entire list of
terms that a user needs to accept, we can treat a
POST request with the "accept" checkbox enabled as a
request to accept all unaccepted (or outdated revision)
terms.

This commit also adds in a new http middleware used to
redirect authenticated users to '/tos' if they have not
yet accepted all terms.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-06-13 12:18:17 -07:00
parent e624e25c0f
commit adb42882c5
4 changed files with 258 additions and 4 deletions

View file

@ -3,8 +3,9 @@ import http
import typing
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import HTMLResponse
from fastapi.responses import HTMLResponse, RedirectResponse
from fastapi.staticfiles import StaticFiles
from sqlalchemy import and_, or_
from starlette.middleware.authentication import AuthenticationMiddleware
from starlette.middleware.sessions import SessionMiddleware
@ -12,7 +13,9 @@ import aurweb.config
import aurweb.logging
from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine
from aurweb.db import get_engine, query
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.term import Term
from aurweb.routers import accounts, auth, errors, html, sso
# Setup the FastAPI app.
@ -97,3 +100,22 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
response.headers["X-Frame-Options"] = xfo
return response
@app.middleware("http")
async def check_terms_of_service(request: Request, call_next: typing.Callable):
""" This middleware function redirects authenticated users if they
have any outstanding Terms to agree to. """
if request.user.is_authenticated() and request.url.path != "/tos":
unaccepted = query(Term).join(AcceptedTerm).filter(
or_(AcceptedTerm.UsersID != request.user.ID,
and_(AcceptedTerm.UsersID == request.user.ID,
AcceptedTerm.TermsID == Term.ID,
AcceptedTerm.Revision < Term.Revision)))
if query(Term).count() > unaccepted.count():
return RedirectResponse(
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
task = asyncio.create_task(call_next(request))
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
return task.result()

View file

@ -1,4 +1,5 @@
import copy
import typing
from datetime import datetime
from http import HTTPStatus
@ -13,9 +14,11 @@ from aurweb import db, l10n, time, util
from aurweb.auth import auth_required
from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
from aurweb.l10n import get_translator_for_request
from aurweb.models.accepted_term import AcceptedTerm
from aurweb.models.account_type import AccountType
from aurweb.models.ban import Ban
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.term import Term
from aurweb.models.user import User
from aurweb.scripts.notify import ResetKeyNotification
from aurweb.templates import make_variable_context, render_template
@ -576,3 +579,77 @@ async def account(request: Request, username: str):
context["user"] = user
return render_template(request, "account/show.html", context)
def render_terms_of_service(request: Request,
context: dict,
terms: typing.Iterable):
if not terms:
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
context["unaccepted_terms"] = terms
return render_template(request, "tos/index.html", context)
@router.get("/tos")
@auth_required(True, redirect="/")
async def terms_of_service(request: Request):
# Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted.
diffs = db.query(Term).join(AcceptedTerm).filter(
AcceptedTerm.Revision < Term.Revision).all()
# Query the database for any terms that have not yet been accepted.
unaccepted = db.query(Term).filter(
~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all()
# Translate the 'Terms of Service' part of our page title.
_ = l10n.get_translator_for_request(request)
title = f"AUR {_('Terms of Service')}"
context = await make_variable_context(request, title)
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
@router.post("/tos")
@auth_required(True, redirect="/")
async def terms_of_service_post(request: Request,
accept: bool = Form(default=False)):
# Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted.
diffs = db.query(Term).join(AcceptedTerm).filter(
AcceptedTerm.Revision < Term.Revision).all()
# Query the database for any terms that have not yet been accepted.
unaccepted = db.query(Term).filter(
~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all()
if not accept:
# Translate the 'Terms of Service' part of our page title.
_ = l10n.get_translator_for_request(request)
title = f"AUR {_('Terms of Service')}"
context = await make_variable_context(request, title)
# We already did the database filters here, so let's just use
# them instead of reiterating the process in terms_of_service.
accept_needed = sorted(unaccepted + diffs)
return render_terms_of_service(request, context, accept_needed)
# For each term we found, query for the matching accepted term
# and update its Revision to the term's current Revision.
for term in diffs:
accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it!
for term in unaccepted:
db.create(AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision,
autocommit=False)
if diffs or unaccepted:
# If we had any terms to update, commit the changes.
db.commit()
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))