diff --git a/aurweb/asgi.py b/aurweb/asgi.py index 45638277..65318907 100644 --- a/aurweb/asgi.py +++ b/aurweb/asgi.py @@ -3,8 +3,9 @@ import http import typing from fastapi import FastAPI, HTTPException, Request -from fastapi.responses import HTMLResponse +from fastapi.responses import HTMLResponse, RedirectResponse from fastapi.staticfiles import StaticFiles +from sqlalchemy import and_, or_ from starlette.middleware.authentication import AuthenticationMiddleware from starlette.middleware.sessions import SessionMiddleware @@ -12,7 +13,9 @@ import aurweb.config import aurweb.logging from aurweb.auth import BasicAuthBackend -from aurweb.db import get_engine +from aurweb.db import get_engine, query +from aurweb.models.accepted_term import AcceptedTerm +from aurweb.models.term import Term from aurweb.routers import accounts, auth, errors, html, sso # Setup the FastAPI app. @@ -97,3 +100,22 @@ async def add_security_headers(request: Request, call_next: typing.Callable): response.headers["X-Frame-Options"] = xfo return response + + +@app.middleware("http") +async def check_terms_of_service(request: Request, call_next: typing.Callable): + """ This middleware function redirects authenticated users if they + have any outstanding Terms to agree to. """ + if request.user.is_authenticated() and request.url.path != "/tos": + unaccepted = query(Term).join(AcceptedTerm).filter( + or_(AcceptedTerm.UsersID != request.user.ID, + and_(AcceptedTerm.UsersID == request.user.ID, + AcceptedTerm.TermsID == Term.ID, + AcceptedTerm.Revision < Term.Revision))) + if query(Term).count() > unaccepted.count(): + return RedirectResponse( + "/tos", status_code=int(http.HTTPStatus.SEE_OTHER)) + + task = asyncio.create_task(call_next(request)) + await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED) + return task.result() diff --git a/aurweb/routers/accounts.py b/aurweb/routers/accounts.py index 966f8409..3e3469ca 100644 --- a/aurweb/routers/accounts.py +++ b/aurweb/routers/accounts.py @@ -1,4 +1,5 @@ import copy +import typing from datetime import datetime from http import HTTPStatus @@ -13,9 +14,11 @@ from aurweb import db, l10n, time, util from aurweb.auth import auth_required from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token from aurweb.l10n import get_translator_for_request +from aurweb.models.accepted_term import AcceptedTerm from aurweb.models.account_type import AccountType from aurweb.models.ban import Ban from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint +from aurweb.models.term import Term from aurweb.models.user import User from aurweb.scripts.notify import ResetKeyNotification from aurweb.templates import make_variable_context, render_template @@ -576,3 +579,77 @@ async def account(request: Request, username: str): context["user"] = user return render_template(request, "account/show.html", context) + + +def render_terms_of_service(request: Request, + context: dict, + terms: typing.Iterable): + if not terms: + return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) + context["unaccepted_terms"] = terms + return render_template(request, "tos/index.html", context) + + +@router.get("/tos") +@auth_required(True, redirect="/") +async def terms_of_service(request: Request): + # Query the database for terms that were previously accepted, + # but now have a bumped Revision that needs to be accepted. + diffs = db.query(Term).join(AcceptedTerm).filter( + AcceptedTerm.Revision < Term.Revision).all() + + # Query the database for any terms that have not yet been accepted. + unaccepted = db.query(Term).filter( + ~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all() + + # Translate the 'Terms of Service' part of our page title. + _ = l10n.get_translator_for_request(request) + title = f"AUR {_('Terms of Service')}" + context = await make_variable_context(request, title) + + accept_needed = sorted(unaccepted + diffs) + return render_terms_of_service(request, context, accept_needed) + + +@router.post("/tos") +@auth_required(True, redirect="/") +async def terms_of_service_post(request: Request, + accept: bool = Form(default=False)): + # Query the database for terms that were previously accepted, + # but now have a bumped Revision that needs to be accepted. + diffs = db.query(Term).join(AcceptedTerm).filter( + AcceptedTerm.Revision < Term.Revision).all() + + # Query the database for any terms that have not yet been accepted. + unaccepted = db.query(Term).filter( + ~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all() + + if not accept: + # Translate the 'Terms of Service' part of our page title. + _ = l10n.get_translator_for_request(request) + title = f"AUR {_('Terms of Service')}" + context = await make_variable_context(request, title) + + # We already did the database filters here, so let's just use + # them instead of reiterating the process in terms_of_service. + accept_needed = sorted(unaccepted + diffs) + return render_terms_of_service(request, context, accept_needed) + + # For each term we found, query for the matching accepted term + # and update its Revision to the term's current Revision. + for term in diffs: + accepted_term = request.user.accepted_terms.filter( + AcceptedTerm.TermsID == term.ID).first() + accepted_term.Revision = term.Revision + + # For each term that was never accepted, accept it! + for term in unaccepted: + db.create(AcceptedTerm, User=request.user, + Term=term, Revision=term.Revision, + autocommit=False) + + if diffs or unaccepted: + # If we had any terms to update, commit the changes. + db.commit() + + return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER)) diff --git a/templates/tos/index.html b/templates/tos/index.html new file mode 100644 index 00000000..a084034e --- /dev/null +++ b/templates/tos/index.html @@ -0,0 +1,46 @@ +{% extends "partials/layout.html" %} + +{% block pageContent %} +
+

AUR {% trans %}Terms of Service{% endtrans %}

+
+
+

+ {{ + "Logged-in as: %s" + | tr | format( + "" + request.user.Username + "") + | safe + }} +

+

+ {{ + "The following documents have been updated. " + "Please review them carefully:" | tr + }} +

+
    + {% for term in unaccepted_terms %} +
  • + {{ term.Description }} + (revision {{ term.Revision }}) +
  • + {% endfor %} +
+

+ {% for term in unaccepted_terms %} + + {% endfor %} + + {{ "I accept the terms and conditions above." | tr }} +

+

+ +

+
+
+
+{% endblock %} diff --git a/test/test_accounts_routes.py b/test/test_accounts_routes.py index bd0d9d4b..96ee3be8 100644 --- a/test/test_accounts_routes.py +++ b/test/test_accounts_routes.py @@ -12,11 +12,13 @@ from fastapi.testclient import TestClient from aurweb import captcha from aurweb.asgi import app -from aurweb.db import create, query +from aurweb.db import commit, create, query +from aurweb.models.accepted_term import AcceptedTerm from aurweb.models.account_type import AccountType from aurweb.models.ban import Ban from aurweb.models.session import Session from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint +from aurweb.models.term import Term from aurweb.models.user import User from aurweb.testing import setup_test_db from aurweb.testing.requests import Request @@ -48,7 +50,7 @@ def make_ssh_pubkey(): def setup(): global user - setup_test_db("Users", "Sessions", "Bans") + setup_test_db("Users", "Sessions", "Bans", "Terms", "AcceptedTerms") account_type = query(AccountType, AccountType.AccountType == "User").first() @@ -919,3 +921,110 @@ def test_get_account_unauthenticated(): content = response.content.decode() assert "You must log in to view user information." in content + + +def test_get_terms_of_service(): + term = create(Term, Description="Test term.", + URL="http://localhost", Revision=1) + + with client as request: + response = request.get("/tos", allow_redirects=False) + assert response.status_code == int(HTTPStatus.SEE_OTHER) + + request = Request() + sid = user.login(request, "testPassword") + cookies = {"AURSID": sid} + + # First of all, let's test that we get redirected to /tos + # when attempting to browse authenticated without accepting terms. + with client as request: + response = request.get("/", cookies=cookies, allow_redirects=False) + assert response.status_code == int(HTTPStatus.SEE_OTHER) + assert response.headers.get("location") == "/tos" + + with client as request: + response = request.get("/tos", cookies=cookies, allow_redirects=False) + assert response.status_code == int(HTTPStatus.OK) + + accepted_term = create(AcceptedTerm, User=user, + Term=term, Revision=term.Revision) + + with client as request: + response = request.get("/tos", cookies=cookies, allow_redirects=False) + # We accepted the term, there's nothing left to accept. + assert response.status_code == int(HTTPStatus.SEE_OTHER) + + # Bump the term's revision. + term.Revision = 2 + commit() + + with client as request: + response = request.get("/tos", cookies=cookies, allow_redirects=False) + # This time, we have a modified term Revision that hasn't + # yet been agreed to via AcceptedTerm update. + assert response.status_code == int(HTTPStatus.OK) + + accepted_term.Revision = term.Revision + commit() + + with client as request: + response = request.get("/tos", cookies=cookies, allow_redirects=False) + # We updated the term revision, there's nothing left to accept. + assert response.status_code == int(HTTPStatus.SEE_OTHER) + + +def test_post_terms_of_service(): + request = Request() + sid = user.login(request, "testPassword") + + data = {"accept": True} # POST data. + cookies = {"AURSID": sid} # Auth cookie. + + # Create a fresh Term. + term = create(Term, Description="Test term.", + URL="http://localhost", Revision=1) + + # Test that the term we just created is listed. + with client as request: + response = request.get("/tos", cookies=cookies) + assert response.status_code == int(HTTPStatus.OK) + + # Make a POST request to /tos with the agree checkbox disabled (False). + with client as request: + response = request.post("/tos", data={"accept": False}, + cookies=cookies) + assert response.status_code == int(HTTPStatus.OK) + + # Make a POST request to /tos with the agree checkbox enabled (True). + with client as request: + response = request.post("/tos", data=data, cookies=cookies) + assert response.status_code == int(HTTPStatus.SEE_OTHER) + + # Query the db for the record created by the post request. + accepted_term = query(AcceptedTerm, + AcceptedTerm.TermsID == term.ID).first() + assert accepted_term.User == user + assert accepted_term.Term == term + + # Update the term to revision 2. + term.Revision = 2 + commit() + + # A GET request gives us the new revision to accept. + with client as request: + response = request.get("/tos", cookies=cookies) + assert response.status_code == int(HTTPStatus.OK) + + # Let's POST again and agree to the new term revision. + with client as request: + response = request.post("/tos", data=data, cookies=cookies) + assert response.status_code == int(HTTPStatus.SEE_OTHER) + + # Check that the records ended up matching. + assert accepted_term.Revision == term.Revision + + # Now, see that GET redirects us to / with no terms left to accept. + with client as request: + response = request.get("/tos", cookies=cookies, allow_redirects=False) + assert response.status_code == int(HTTPStatus.SEE_OTHER) + assert response.headers.get("location") == "/"