diff --git a/aurweb/asgi.py b/aurweb/asgi.py index 45638277..65318907 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -3,8 +3,9 @@ import http import typing from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from sqlalchemy import and_, or_ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -12,7 +13,9 @@ import aurweb.config import aurweb.logging from aurweb.auth import BasicAuthBackend -from aurweb.db import get_engine +from aurweb.db import get_engine, query +from aurweb.models.accepted_term import AcceptedTerm +from aurweb.models.term import Term from aurweb.routers import accounts, auth, errors, html, sso # Setup the FastAPI app. @@ -97,3 +100,22 @@ async def add_security_headers(request: Request, call_next: typing.Callable): response.headers["X-Frame-Options"] = xfo return response + + +@app.middleware("http") +async def check_terms_of_service(request: Request, call_next: typing.Callable): + """ This middleware function redirects authenticated users if they + have any outstanding Terms to agree to. """ + if request.user.is_authenticated() and request.url.path != "/tos": + unaccepted = query(Term).join(AcceptedTerm).filter( + or_(AcceptedTerm.UsersID != request.user.ID, + and_(AcceptedTerm.UsersID == request.user.ID, + AcceptedTerm.TermsID == Term.ID, + AcceptedTerm.Revision < Term.Revision))) + if query(Term).count() > unaccepted.count(): + return RedirectResponse( + "/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) + + task = asyncio.create_task(call_next(request)) + await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) + return task.result() diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 966f8409..3e3469ca 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -1,4 +1,5 @@ import copy +import typing from datetime import datetime from http import HTTPStatus @@ -13,9 +14,11 @@ from aurweb import db, l10n, time, util from aurweb.auth import auth_required from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token from aurweb.l10n import get_translator_for_request +from aurweb.models.accepted_term import AcceptedTerm from aurweb.models.account_type import AccountType from aurweb.models.ban import Ban from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint +from aurweb.models.term import Term from aurweb.models.user import User from aurweb.scripts.notify import ResetKeyNotification from aurweb.templates import make_variable_context, render_template @@ -576,3 +579,77 @@ async def account(request: Request, username: str): context["user"] = user return render_template(request, "account/show.html", context) + + +def render_terms_of_service(request: Request, + context: dict, + terms: typing.Iterable): + if not terms: + return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) + context["unaccepted_terms"] = terms + return render_template(request, "tos/index.html", context) + + +@router.get("/tos") +@auth_required(True, redirect="/") +async def terms_of_service(request: Request): + # Query the database for terms that were previously accepted, + # but now have a bumped Revision that needs to be accepted. + diffs = db.query(Term).join(AcceptedTerm).filter( + AcceptedTerm.Revision < Term.Revision).all() + + # Query the database for any terms that have not yet been accepted. + unaccepted = db.query(Term).filter( + ~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all() + + # Translate the 'Terms of Service' part of our page title. + _ = l10n.get_translator_for_request(request) + title = f"AUR {_('Terms of Service')}" + context = await make_variable_context(request, title) + + accept_needed = sorted(unaccepted + diffs) + return render_terms_of_service(request, context, accept_needed) + + +@router.post("/tos") +@auth_required(True, redirect="/") +async def terms_of_service_post(request: Request, + accept: bool = Form(default=False)): + # Query the database for terms that were previously accepted, + # but now have a bumped Revision that needs to be accepted. + diffs = db.query(Term).join(AcceptedTerm).filter( + AcceptedTerm.Revision < Term.Revision).all() + + # Query the database for any terms that have not yet been accepted. + unaccepted = db.query(Term).filter( + ~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all() + + if not accept: + # Translate the 'Terms of Service' part of our page title. + _ = l10n.get_translator_for_request(request) + title = f"AUR {_('Terms of Service')}" + context = await make_variable_context(request, title) + + # We already did the database filters here, so let's just use + # them instead of reiterating the process in terms_of_service. + accept_needed = sorted(unaccepted + diffs) + return render_terms_of_service(request, context, accept_needed) + + # For each term we found, query for the matching accepted term + # and update its Revision to the term's current Revision. + for term in diffs: + accepted_term = request.user.accepted_terms.filter( + AcceptedTerm.TermsID == term.ID).first() + accepted_term.Revision = term.Revision + + # For each term that was never accepted, accept it! + for term in unaccepted: + db.create(AcceptedTerm, User=request.user, + Term=term, Revision=term.Revision, + autocommit=False) + + if diffs or unaccepted: + # If we had any terms to update, commit the changes. + db.commit() + + return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) diff --git a/templates/tos/index.html b/templates/tos/index.html new file mode 100644 index 00000000..a084034e --- /dev/null +++ b/templates/tos/index.html @@ -0,0 +1,46 @@ +{% extends "partials/layout.html" %} + +{% block pageContent %} +