mirror of
https://gitlab.archlinux.org/archlinux/aurweb.git
synced 2025-02-03 10:43:03 +01:00
[FastAPI] add /tos routes (get, post)
This clones the end goal behavior of PHP, but it does not concern itself with the revision form array at all. Since this page on PHP renders out the entire list of terms that a user needs to accept, we can treat a POST request with the "accept" checkbox enabled as a request to accept all unaccepted (or outdated revision) terms. This commit also adds in a new http middleware used to redirect authenticated users to '/tos' if they have not yet accepted all terms. Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
parent
e624e25c0f
commit
adb42882c5
4 changed files with 258 additions and 4 deletions
|
@ -3,8 +3,9 @@ import http
|
||||||
import typing
|
import typing
|
||||||
|
|
||||||
from fastapi import FastAPI, HTTPException, Request
|
from fastapi import FastAPI, HTTPException, Request
|
||||||
from fastapi.responses import HTMLResponse
|
from fastapi.responses import HTMLResponse, RedirectResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
|
from sqlalchemy import and_, or_
|
||||||
from starlette.middleware.authentication import AuthenticationMiddleware
|
from starlette.middleware.authentication import AuthenticationMiddleware
|
||||||
from starlette.middleware.sessions import SessionMiddleware
|
from starlette.middleware.sessions import SessionMiddleware
|
||||||
|
|
||||||
|
@ -12,7 +13,9 @@ import aurweb.config
|
||||||
import aurweb.logging
|
import aurweb.logging
|
||||||
|
|
||||||
from aurweb.auth import BasicAuthBackend
|
from aurweb.auth import BasicAuthBackend
|
||||||
from aurweb.db import get_engine
|
from aurweb.db import get_engine, query
|
||||||
|
from aurweb.models.accepted_term import AcceptedTerm
|
||||||
|
from aurweb.models.term import Term
|
||||||
from aurweb.routers import accounts, auth, errors, html, sso
|
from aurweb.routers import accounts, auth, errors, html, sso
|
||||||
|
|
||||||
# Setup the FastAPI app.
|
# Setup the FastAPI app.
|
||||||
|
@ -97,3 +100,22 @@ async def add_security_headers(request: Request, call_next: typing.Callable):
|
||||||
response.headers["X-Frame-Options"] = xfo
|
response.headers["X-Frame-Options"] = xfo
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@app.middleware("http")
|
||||||
|
async def check_terms_of_service(request: Request, call_next: typing.Callable):
|
||||||
|
""" This middleware function redirects authenticated users if they
|
||||||
|
have any outstanding Terms to agree to. """
|
||||||
|
if request.user.is_authenticated() and request.url.path != "/tos":
|
||||||
|
unaccepted = query(Term).join(AcceptedTerm).filter(
|
||||||
|
or_(AcceptedTerm.UsersID != request.user.ID,
|
||||||
|
and_(AcceptedTerm.UsersID == request.user.ID,
|
||||||
|
AcceptedTerm.TermsID == Term.ID,
|
||||||
|
AcceptedTerm.Revision < Term.Revision)))
|
||||||
|
if query(Term).count() > unaccepted.count():
|
||||||
|
return RedirectResponse(
|
||||||
|
"/tos", status_code=int(http.HTTPStatus.SEE_OTHER))
|
||||||
|
|
||||||
|
task = asyncio.create_task(call_next(request))
|
||||||
|
await asyncio.wait({task}, return_when=asyncio.FIRST_COMPLETED)
|
||||||
|
return task.result()
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
import copy
|
import copy
|
||||||
|
import typing
|
||||||
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
|
@ -13,9 +14,11 @@ from aurweb import db, l10n, time, util
|
||||||
from aurweb.auth import auth_required
|
from aurweb.auth import auth_required
|
||||||
from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
|
from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
|
||||||
from aurweb.l10n import get_translator_for_request
|
from aurweb.l10n import get_translator_for_request
|
||||||
|
from aurweb.models.accepted_term import AcceptedTerm
|
||||||
from aurweb.models.account_type import AccountType
|
from aurweb.models.account_type import AccountType
|
||||||
from aurweb.models.ban import Ban
|
from aurweb.models.ban import Ban
|
||||||
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
|
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
|
||||||
|
from aurweb.models.term import Term
|
||||||
from aurweb.models.user import User
|
from aurweb.models.user import User
|
||||||
from aurweb.scripts.notify import ResetKeyNotification
|
from aurweb.scripts.notify import ResetKeyNotification
|
||||||
from aurweb.templates import make_variable_context, render_template
|
from aurweb.templates import make_variable_context, render_template
|
||||||
|
@ -576,3 +579,77 @@ async def account(request: Request, username: str):
|
||||||
context["user"] = user
|
context["user"] = user
|
||||||
|
|
||||||
return render_template(request, "account/show.html", context)
|
return render_template(request, "account/show.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
def render_terms_of_service(request: Request,
|
||||||
|
context: dict,
|
||||||
|
terms: typing.Iterable):
|
||||||
|
if not terms:
|
||||||
|
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
|
||||||
|
context["unaccepted_terms"] = terms
|
||||||
|
return render_template(request, "tos/index.html", context)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/tos")
|
||||||
|
@auth_required(True, redirect="/")
|
||||||
|
async def terms_of_service(request: Request):
|
||||||
|
# Query the database for terms that were previously accepted,
|
||||||
|
# but now have a bumped Revision that needs to be accepted.
|
||||||
|
diffs = db.query(Term).join(AcceptedTerm).filter(
|
||||||
|
AcceptedTerm.Revision < Term.Revision).all()
|
||||||
|
|
||||||
|
# Query the database for any terms that have not yet been accepted.
|
||||||
|
unaccepted = db.query(Term).filter(
|
||||||
|
~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all()
|
||||||
|
|
||||||
|
# Translate the 'Terms of Service' part of our page title.
|
||||||
|
_ = l10n.get_translator_for_request(request)
|
||||||
|
title = f"AUR {_('Terms of Service')}"
|
||||||
|
context = await make_variable_context(request, title)
|
||||||
|
|
||||||
|
accept_needed = sorted(unaccepted + diffs)
|
||||||
|
return render_terms_of_service(request, context, accept_needed)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/tos")
|
||||||
|
@auth_required(True, redirect="/")
|
||||||
|
async def terms_of_service_post(request: Request,
|
||||||
|
accept: bool = Form(default=False)):
|
||||||
|
# Query the database for terms that were previously accepted,
|
||||||
|
# but now have a bumped Revision that needs to be accepted.
|
||||||
|
diffs = db.query(Term).join(AcceptedTerm).filter(
|
||||||
|
AcceptedTerm.Revision < Term.Revision).all()
|
||||||
|
|
||||||
|
# Query the database for any terms that have not yet been accepted.
|
||||||
|
unaccepted = db.query(Term).filter(
|
||||||
|
~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all()
|
||||||
|
|
||||||
|
if not accept:
|
||||||
|
# Translate the 'Terms of Service' part of our page title.
|
||||||
|
_ = l10n.get_translator_for_request(request)
|
||||||
|
title = f"AUR {_('Terms of Service')}"
|
||||||
|
context = await make_variable_context(request, title)
|
||||||
|
|
||||||
|
# We already did the database filters here, so let's just use
|
||||||
|
# them instead of reiterating the process in terms_of_service.
|
||||||
|
accept_needed = sorted(unaccepted + diffs)
|
||||||
|
return render_terms_of_service(request, context, accept_needed)
|
||||||
|
|
||||||
|
# For each term we found, query for the matching accepted term
|
||||||
|
# and update its Revision to the term's current Revision.
|
||||||
|
for term in diffs:
|
||||||
|
accepted_term = request.user.accepted_terms.filter(
|
||||||
|
AcceptedTerm.TermsID == term.ID).first()
|
||||||
|
accepted_term.Revision = term.Revision
|
||||||
|
|
||||||
|
# For each term that was never accepted, accept it!
|
||||||
|
for term in unaccepted:
|
||||||
|
db.create(AcceptedTerm, User=request.user,
|
||||||
|
Term=term, Revision=term.Revision,
|
||||||
|
autocommit=False)
|
||||||
|
|
||||||
|
if diffs or unaccepted:
|
||||||
|
# If we had any terms to update, commit the changes.
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return RedirectResponse("/", status_code=int(HTTPStatus.SEE_OTHER))
|
||||||
|
|
46
templates/tos/index.html
Normal file
46
templates/tos/index.html
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
{% extends "partials/layout.html" %}
|
||||||
|
|
||||||
|
{% block pageContent %}
|
||||||
|
<div id="dev-login" class="box">
|
||||||
|
<h2>AUR {% trans %}Terms of Service{% endtrans %}</h2>
|
||||||
|
<form method="post" action="/tos">
|
||||||
|
<fieldset>
|
||||||
|
<p>
|
||||||
|
{{
|
||||||
|
"Logged-in as: %s"
|
||||||
|
| tr | format(
|
||||||
|
"<strong>" + request.user.Username + "</strong>")
|
||||||
|
| safe
|
||||||
|
}}
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
{{
|
||||||
|
"The following documents have been updated. "
|
||||||
|
"Please review them carefully:" | tr
|
||||||
|
}}
|
||||||
|
</p>
|
||||||
|
<ul>
|
||||||
|
{% for term in unaccepted_terms %}
|
||||||
|
<li>
|
||||||
|
<a href="{{ term.URL }}">{{ term.Description }}</a>
|
||||||
|
(revision {{ term.Revision }})
|
||||||
|
</li>
|
||||||
|
{% endfor %}
|
||||||
|
</ul>
|
||||||
|
<p>
|
||||||
|
{% for term in unaccepted_terms %}
|
||||||
|
<input type="hidden"
|
||||||
|
name="rev[{{ loop.index }}]"
|
||||||
|
value="{{ term.Revision }}" />
|
||||||
|
{% endfor %}
|
||||||
|
<input type="checkbox" name="accept" />
|
||||||
|
{{ "I accept the terms and conditions above." | tr }}
|
||||||
|
</p>
|
||||||
|
<p>
|
||||||
|
<input type="submit" name="submit"
|
||||||
|
value="{{ 'Continue' | tr }}" />
|
||||||
|
</p>
|
||||||
|
</fieldset>
|
||||||
|
</form>
|
||||||
|
</div>
|
||||||
|
{% endblock %}
|
|
@ -12,11 +12,13 @@ from fastapi.testclient import TestClient
|
||||||
|
|
||||||
from aurweb import captcha
|
from aurweb import captcha
|
||||||
from aurweb.asgi import app
|
from aurweb.asgi import app
|
||||||
from aurweb.db import create, query
|
from aurweb.db import commit, create, query
|
||||||
|
from aurweb.models.accepted_term import AcceptedTerm
|
||||||
from aurweb.models.account_type import AccountType
|
from aurweb.models.account_type import AccountType
|
||||||
from aurweb.models.ban import Ban
|
from aurweb.models.ban import Ban
|
||||||
from aurweb.models.session import Session
|
from aurweb.models.session import Session
|
||||||
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
|
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
|
||||||
|
from aurweb.models.term import Term
|
||||||
from aurweb.models.user import User
|
from aurweb.models.user import User
|
||||||
from aurweb.testing import setup_test_db
|
from aurweb.testing import setup_test_db
|
||||||
from aurweb.testing.requests import Request
|
from aurweb.testing.requests import Request
|
||||||
|
@ -48,7 +50,7 @@ def make_ssh_pubkey():
|
||||||
def setup():
|
def setup():
|
||||||
global user
|
global user
|
||||||
|
|
||||||
setup_test_db("Users", "Sessions", "Bans")
|
setup_test_db("Users", "Sessions", "Bans", "Terms", "AcceptedTerms")
|
||||||
|
|
||||||
account_type = query(AccountType,
|
account_type = query(AccountType,
|
||||||
AccountType.AccountType == "User").first()
|
AccountType.AccountType == "User").first()
|
||||||
|
@ -919,3 +921,110 @@ def test_get_account_unauthenticated():
|
||||||
|
|
||||||
content = response.content.decode()
|
content = response.content.decode()
|
||||||
assert "You must log in to view user information." in content
|
assert "You must log in to view user information." in content
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_terms_of_service():
|
||||||
|
term = create(Term, Description="Test term.",
|
||||||
|
URL="http://localhost", Revision=1)
|
||||||
|
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", allow_redirects=False)
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
|
||||||
|
request = Request()
|
||||||
|
sid = user.login(request, "testPassword")
|
||||||
|
cookies = {"AURSID": sid}
|
||||||
|
|
||||||
|
# First of all, let's test that we get redirected to /tos
|
||||||
|
# when attempting to browse authenticated without accepting terms.
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/", cookies=cookies, allow_redirects=False)
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
assert response.headers.get("location") == "/tos"
|
||||||
|
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies, allow_redirects=False)
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
|
accepted_term = create(AcceptedTerm, User=user,
|
||||||
|
Term=term, Revision=term.Revision)
|
||||||
|
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies, allow_redirects=False)
|
||||||
|
# We accepted the term, there's nothing left to accept.
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
|
||||||
|
# Bump the term's revision.
|
||||||
|
term.Revision = 2
|
||||||
|
commit()
|
||||||
|
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies, allow_redirects=False)
|
||||||
|
# This time, we have a modified term Revision that hasn't
|
||||||
|
# yet been agreed to via AcceptedTerm update.
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
|
accepted_term.Revision = term.Revision
|
||||||
|
commit()
|
||||||
|
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies, allow_redirects=False)
|
||||||
|
# We updated the term revision, there's nothing left to accept.
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
|
||||||
|
|
||||||
|
def test_post_terms_of_service():
|
||||||
|
request = Request()
|
||||||
|
sid = user.login(request, "testPassword")
|
||||||
|
|
||||||
|
data = {"accept": True} # POST data.
|
||||||
|
cookies = {"AURSID": sid} # Auth cookie.
|
||||||
|
|
||||||
|
# Create a fresh Term.
|
||||||
|
term = create(Term, Description="Test term.",
|
||||||
|
URL="http://localhost", Revision=1)
|
||||||
|
|
||||||
|
# Test that the term we just created is listed.
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies)
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Make a POST request to /tos with the agree checkbox disabled (False).
|
||||||
|
with client as request:
|
||||||
|
response = request.post("/tos", data={"accept": False},
|
||||||
|
cookies=cookies)
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Make a POST request to /tos with the agree checkbox enabled (True).
|
||||||
|
with client as request:
|
||||||
|
response = request.post("/tos", data=data, cookies=cookies)
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
|
||||||
|
# Query the db for the record created by the post request.
|
||||||
|
accepted_term = query(AcceptedTerm,
|
||||||
|
AcceptedTerm.TermsID == term.ID).first()
|
||||||
|
assert accepted_term.User == user
|
||||||
|
assert accepted_term.Term == term
|
||||||
|
|
||||||
|
# Update the term to revision 2.
|
||||||
|
term.Revision = 2
|
||||||
|
commit()
|
||||||
|
|
||||||
|
# A GET request gives us the new revision to accept.
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies)
|
||||||
|
assert response.status_code == int(HTTPStatus.OK)
|
||||||
|
|
||||||
|
# Let's POST again and agree to the new term revision.
|
||||||
|
with client as request:
|
||||||
|
response = request.post("/tos", data=data, cookies=cookies)
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
|
||||||
|
# Check that the records ended up matching.
|
||||||
|
assert accepted_term.Revision == term.Revision
|
||||||
|
|
||||||
|
# Now, see that GET redirects us to / with no terms left to accept.
|
||||||
|
with client as request:
|
||||||
|
response = request.get("/tos", cookies=cookies, allow_redirects=False)
|
||||||
|
assert response.status_code == int(HTTPStatus.SEE_OTHER)
|
||||||
|
assert response.headers.get("location") == "/"
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue