change(fastapi): simplify model imports across code-base

Closes: #133

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-10-16 19:25:25 -07:00
parent bfdc85d7d6
commit 28c4e9697b
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
12 changed files with 341 additions and 343 deletions

View file

@ -16,8 +16,7 @@ import aurweb.logging
from aurweb.auth import BasicAuthBackend from aurweb.auth import BasicAuthBackend
from aurweb.db import get_engine, query from aurweb.db import get_engine, query
from aurweb.models.accepted_term import AcceptedTerm from aurweb.models import AcceptedTerm, Term
from aurweb.models.term import Term
from aurweb.routers import accounts, auth, errors, html, packages, rpc, rss, sso, trusted_user from aurweb.routers import accounts, auth, errors, html, packages, rpc, rss, sso, trusted_user
# Setup the FastAPI app. # Setup the FastAPI app.

View file

@ -14,9 +14,8 @@ from starlette.requests import HTTPConnection
import aurweb.config import aurweb.config
from aurweb import l10n, util from aurweb import l10n, util
from aurweb.models import Session, User
from aurweb.models.account_type import ACCOUNT_TYPE_ID from aurweb.models.account_type import ACCOUNT_TYPE_ID
from aurweb.models.session import Session
from aurweb.models.user import User
from aurweb.templates import make_variable_context, render_template from aurweb.templates import make_variable_context, render_template

View file

@ -4,7 +4,7 @@ import hashlib
from jinja2 import pass_context from jinja2 import pass_context
from aurweb.db import query from aurweb.db import query
from aurweb.models.user import User from aurweb.models import User
def get_captcha_salts(): def get_captcha_salts():

View file

@ -1,13 +1,6 @@
from sqlalchemy import and_, case, or_, orm from sqlalchemy import and_, case, or_, orm
from aurweb import config, db from aurweb import config, db, models
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_keyword import PackageKeyword
from aurweb.models.package_notification import PackageNotification
from aurweb.models.package_vote import PackageVote
from aurweb.models.user import User
DEFAULT_MAX_RESULTS = 2500 DEFAULT_MAX_RESULTS = 2500
@ -18,22 +11,22 @@ class PackageSearch:
# A constant mapping of short to full name sort orderings. # A constant mapping of short to full name sort orderings.
FULL_SORT_ORDER = {"d": "desc", "a": "asc"} FULL_SORT_ORDER = {"d": "desc", "a": "asc"}
def __init__(self, user: User): def __init__(self, user: models.User):
""" Construct an instance of PackageSearch. """ Construct an instance of PackageSearch.
This constructors performs several steps during initialization: This constructors performs several steps during initialization:
1. Setup self.query: an ORM query of Package joined by PackageBase. 1. Setup self.query: an ORM query of Package joined by PackageBase.
""" """
self.user = user self.user = user
self.query = db.query(Package).join(PackageBase).join( self.query = db.query(models.Package).join(models.PackageBase).join(
PackageVote, models.PackageVote,
and_(PackageVote.PackageBaseID == PackageBase.ID, and_(models.PackageVote.PackageBaseID == models.PackageBase.ID,
PackageVote.UsersID == self.user.ID), models.PackageVote.UsersID == self.user.ID),
isouter=True isouter=True
).join( ).join(
PackageNotification, models.PackageNotification,
and_(PackageNotification.PackageBaseID == PackageBase.ID, and_(models.PackageNotification.PackageBaseID == models.PackageBase.ID,
PackageNotification.UserID == self.user.ID), models.PackageNotification.UserID == self.user.ID),
isouter=True isouter=True
) )
self.ordering = "d" self.ordering = "d"
@ -65,59 +58,64 @@ class PackageSearch:
def _search_by_namedesc(self, keywords: str) -> orm.Query: def _search_by_namedesc(self, keywords: str) -> orm.Query:
self.query = self.query.filter( self.query = self.query.filter(
or_(Package.Name.like(f"%{keywords}%"), or_(models.Package.Name.like(f"%{keywords}%"),
Package.Description.like(f"%{keywords}%")) models.Package.Description.like(f"%{keywords}%"))
) )
return self return self
def _search_by_name(self, keywords: str) -> orm.Query: def _search_by_name(self, keywords: str) -> orm.Query:
self.query = self.query.filter(Package.Name.like(f"%{keywords}%")) self.query = self.query.filter(
models.Package.Name.like(f"%{keywords}%"))
return self return self
def _search_by_exact_name(self, keywords: str) -> orm.Query: def _search_by_exact_name(self, keywords: str) -> orm.Query:
self.query = self.query.filter(Package.Name == keywords) self.query = self.query.filter(
models.Package.Name == keywords)
return self return self
def _search_by_pkgbase(self, keywords: str) -> orm.Query: def _search_by_pkgbase(self, keywords: str) -> orm.Query:
self.query = self.query.filter(PackageBase.Name.like(f"%{keywords}%")) self.query = self.query.filter(
models.PackageBase.Name.like(f"%{keywords}%"))
return self return self
def _search_by_exact_pkgbase(self, keywords: str) -> orm.Query: def _search_by_exact_pkgbase(self, keywords: str) -> orm.Query:
self.query = self.query.filter(PackageBase.Name == keywords) self.query = self.query.filter(
models.PackageBase.Name == keywords)
return self return self
def _search_by_keywords(self, keywords: str) -> orm.Query: def _search_by_keywords(self, keywords: str) -> orm.Query:
self.query = self.query.join(PackageKeyword).filter( self.query = self.query.join(models.PackageKeyword).filter(
PackageKeyword.Keyword == keywords models.PackageKeyword.Keyword == keywords
) )
return self return self
def _search_by_maintainer(self, keywords: str) -> orm.Query: def _search_by_maintainer(self, keywords: str) -> orm.Query:
self.query = self.query.join( self.query = self.query.join(
User, User.ID == PackageBase.MaintainerUID models.User, models.User.ID == models.PackageBase.MaintainerUID
).filter(User.Username == keywords) ).filter(models.User.Username == keywords)
return self return self
def _search_by_comaintainer(self, keywords: str) -> orm.Query: def _search_by_comaintainer(self, keywords: str) -> orm.Query:
self.query = self.query.join(PackageComaintainer).join( self.query = self.query.join(models.PackageComaintainer).join(
User, User.ID == PackageComaintainer.UsersID models.User, models.User.ID == models.PackageComaintainer.UsersID
).filter(User.Username == keywords) ).filter(models.User.Username == keywords)
return self return self
def _search_by_co_or_maintainer(self, keywords: str) -> orm.Query: def _search_by_co_or_maintainer(self, keywords: str) -> orm.Query:
self.query = self.query.join( self.query = self.query.join(
PackageComaintainer, models.PackageComaintainer,
isouter=True isouter=True
).join( ).join(
User, or_(User.ID == PackageBase.MaintainerUID, models.User,
User.ID == PackageComaintainer.UsersID) or_(models.User.ID == models.PackageBase.MaintainerUID,
).filter(User.Username == keywords) models.User.ID == models.PackageComaintainer.UsersID)
).filter(models.User.Username == keywords)
return self return self
def _search_by_submitter(self, keywords: str) -> orm.Query: def _search_by_submitter(self, keywords: str) -> orm.Query:
self.query = self.query.join( self.query = self.query.join(
User, User.ID == PackageBase.SubmitterUID models.User, models.User.ID == models.PackageBase.SubmitterUID
).filter(User.Username == keywords) ).filter(models.User.Username == keywords)
return self return self
def search_by(self, search_by: str, keywords: str) -> orm.Query: def search_by(self, search_by: str, keywords: str) -> orm.Query:
@ -128,17 +126,17 @@ class PackageSearch:
return result return result
def _sort_by_name(self, order: str): def _sort_by_name(self, order: str):
column = getattr(Package.Name, order) column = getattr(models.Package.Name, order)
self.query = self.query.order_by(column()) self.query = self.query.order_by(column())
return self return self
def _sort_by_votes(self, order: str): def _sort_by_votes(self, order: str):
column = getattr(PackageBase.NumVotes, order) column = getattr(models.PackageBase.NumVotes, order)
self.query = self.query.order_by(column()) self.query = self.query.order_by(column())
return self return self
def _sort_by_popularity(self, order: str): def _sort_by_popularity(self, order: str):
column = getattr(PackageBase.Popularity, order) column = getattr(models.PackageBase.Popularity, order)
self.query = self.query.order_by(column()) self.query = self.query.order_by(column())
return self return self
@ -147,10 +145,10 @@ class PackageSearch:
# in terms of performance. We should improve this; there's no # in terms of performance. We should improve this; there's no
# reason it should take _longer_. # reason it should take _longer_.
column = getattr( column = getattr(
case([(PackageVote.UsersID == self.user.ID, 1)], else_=0), case([(models.PackageVote.UsersID == self.user.ID, 1)], else_=0),
order order
) )
self.query = self.query.order_by(column(), Package.Name.desc()) self.query = self.query.order_by(column(), models.Package.Name.desc())
return self return self
def _sort_by_notify(self, order: str): def _sort_by_notify(self, order: str):
@ -158,21 +156,24 @@ class PackageSearch:
# in terms of performance. We should improve this; there's no # in terms of performance. We should improve this; there's no
# reason it should take _longer_. # reason it should take _longer_.
column = getattr( column = getattr(
case([(PackageNotification.UserID == self.user.ID, 1)], else_=0), case([(models.PackageNotification.UserID == self.user.ID, 1)],
else_=0),
order order
) )
self.query = self.query.order_by(column(), Package.Name.desc()) self.query = self.query.order_by(column(), models.Package.Name.desc())
return self return self
def _sort_by_maintainer(self, order: str): def _sort_by_maintainer(self, order: str):
column = getattr(User.Username, order) column = getattr(models.User.Username, order)
self.query = self.query.join( self.query = self.query.join(
User, User.ID == PackageBase.MaintainerUID, isouter=True models.User,
models.User.ID == models.PackageBase.MaintainerUID,
isouter=True
).order_by(column()) ).order_by(column())
return self return self
def _sort_by_last_modified(self, order: str): def _sort_by_last_modified(self, order: str):
column = getattr(PackageBase.ModifiedTS, order) column = getattr(models.PackageBase.ModifiedTS, order)
self.query = self.query.order_by(column()) self.query = self.query.order_by(column())
return self return self

View file

@ -7,43 +7,35 @@ import orjson
from fastapi import HTTPException from fastapi import HTTPException
from sqlalchemy import and_, orm from sqlalchemy import and_, orm
from aurweb import db from aurweb import db, models
from aurweb.models.official_provider import OFFICIAL_BASE, OfficialProvider from aurweb.models.official_provider import OFFICIAL_BASE
from aurweb.models.package import Package from aurweb.models.relation_type import PROVIDES_ID
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comment import PackageComment
from aurweb.models.package_dependency import PackageDependency
from aurweb.models.package_notification import PackageNotification
from aurweb.models.package_relation import PackageRelation
from aurweb.models.package_vote import PackageVote
from aurweb.models.relation_type import PROVIDES_ID, RelationType
from aurweb.models.user import User
from aurweb.redis import redis_connection from aurweb.redis import redis_connection
from aurweb.templates import register_filter from aurweb.templates import register_filter
def dep_depends_extra(dep: PackageDependency) -> str: def dep_depends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """ """ A function used to produce extra text for dependency display. """
return str() return str()
def dep_makedepends_extra(dep: PackageDependency) -> str: def dep_makedepends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """ """ A function used to produce extra text for dependency display. """
return "(make)" return "(make)"
def dep_checkdepends_extra(dep: PackageDependency) -> str: def dep_checkdepends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """ """ A function used to produce extra text for dependency display. """
return "(check)" return "(check)"
def dep_optdepends_extra(dep: PackageDependency) -> str: def dep_optdepends_extra(dep: models.PackageDependency) -> str:
""" A function used to produce extra text for dependency display. """ """ A function used to produce extra text for dependency display. """
return "(optional)" return "(optional)"
@register_filter("dep_extra") @register_filter("dep_extra")
def dep_extra(dep: PackageDependency) -> str: def dep_extra(dep: models.PackageDependency) -> str:
""" Some dependency types have extra text added to their """ Some dependency types have extra text added to their
display. This function provides that output. However, it display. This function provides that output. However, it
**assumes** that the dep passed is bound to a valid one **assumes** that the dep passed is bound to a valid one
@ -53,7 +45,7 @@ def dep_extra(dep: PackageDependency) -> str:
@register_filter("dep_extra_desc") @register_filter("dep_extra_desc")
def dep_extra_desc(dep: PackageDependency) -> str: def dep_extra_desc(dep: models.PackageDependency) -> str:
extra = dep_extra(dep) extra = dep_extra(dep)
if not dep.DepDesc: if not dep.DepDesc:
return extra return extra
@ -63,30 +55,30 @@ def dep_extra_desc(dep: PackageDependency) -> str:
@register_filter("pkgname_link") @register_filter("pkgname_link")
def pkgname_link(pkgname: str) -> str: def pkgname_link(pkgname: str) -> str:
base = "/".join([OFFICIAL_BASE, "packages"]) base = "/".join([OFFICIAL_BASE, "packages"])
official = db.query(OfficialProvider).filter( official = db.query(models.OfficialProvider).filter(
OfficialProvider.Name == pkgname) models.OfficialProvider.Name == pkgname)
if official.scalar(): if official.scalar():
return f"{base}/?q={pkgname}" return f"{base}/?q={pkgname}"
return f"/packages/{pkgname}" return f"/packages/{pkgname}"
@register_filter("package_link") @register_filter("package_link")
def package_link(package: Package) -> str: def package_link(package: models.Package) -> str:
base = "/".join([OFFICIAL_BASE, "packages"]) base = "/".join([OFFICIAL_BASE, "packages"])
official = db.query(OfficialProvider).filter( official = db.query(models.OfficialProvider).filter(
OfficialProvider.Name == package.Name) models.OfficialProvider.Name == package.Name)
if official.scalar(): if official.scalar():
return f"{base}/?q={package.Name}" return f"{base}/?q={package.Name}"
return f"/packages/{package.Name}" return f"/packages/{package.Name}"
@register_filter("provides_list") @register_filter("provides_list")
def provides_list(package: Package, depname: str) -> list: def provides_list(package: models.Package, depname: str) -> list:
providers = db.query(Package).join( providers = db.query(models.Package).join(
PackageRelation).join(RelationType).filter( models.PackageRelation).join(models.RelationType).filter(
and_( and_(
PackageRelation.RelName == depname, models.PackageRelation.RelName == depname,
RelationType.ID == PROVIDES_ID models.RelationType.ID == PROVIDES_ID
) )
) )
@ -102,7 +94,9 @@ def provides_list(package: Package, depname: str) -> list:
return string return string
def get_pkg_or_base(name: str, cls: Union[Package, PackageBase] = PackageBase): def get_pkg_or_base(
name: str,
cls: Union[models.Package, models.PackageBase] = models.PackageBase):
""" Get a PackageBase instance by its name or raise a 404 if """ Get a PackageBase instance by its name or raise a 404 if
it can't be found in the database. it can't be found in the database.
@ -110,20 +104,21 @@ def get_pkg_or_base(name: str, cls: Union[Package, PackageBase] = PackageBase):
:raises HTTPException: With status code 404 if record doesn't exist :raises HTTPException: With status code 404 if record doesn't exist
:return: {Package,PackageBase} instance :return: {Package,PackageBase} instance
""" """
provider = db.query(OfficialProvider).filter( provider = db.query(models.OfficialProvider).filter(
OfficialProvider.Name == name).first() models.OfficialProvider.Name == name).first()
if provider: if provider:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
instance = db.query(cls).filter(cls.Name == name).first() instance = db.query(cls).filter(cls.Name == name).first()
if cls == PackageBase and not instance: if cls == models.PackageBase and not instance:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return instance return instance
def get_pkgbase_comment(pkgbase: PackageBase, id: int) -> PackageComment: def get_pkgbase_comment(
comment = pkgbase.comments.filter(PackageComment.ID == id).first() pkgbase: models.PackageBase, id: int) -> models.PackageComment:
comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment: if not comment:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
return comment return comment
@ -131,10 +126,11 @@ def get_pkgbase_comment(pkgbase: PackageBase, id: int) -> PackageComment:
@register_filter("out_of_date") @register_filter("out_of_date")
def out_of_date(packages: orm.Query) -> orm.Query: def out_of_date(packages: orm.Query) -> orm.Query:
return packages.filter(PackageBase.OutOfDateTS.isnot(None)) return packages.filter(models.PackageBase.OutOfDateTS.isnot(None))
def updated_packages(limit: int = 0, cache_ttl: int = 600) -> List[Package]: def updated_packages(limit: int = 0,
cache_ttl: int = 600) -> List[models.Package]:
""" Return a list of valid Package objects ordered by their """ Return a list of valid Package objects ordered by their
ModifiedTS column in descending order from cache, after setting ModifiedTS column in descending order from cache, after setting
the cache when no key yet exists. the cache when no key yet exists.
@ -149,10 +145,10 @@ def updated_packages(limit: int = 0, cache_ttl: int = 600) -> List[Package]:
# If we already have a cache, deserialize it and return. # If we already have a cache, deserialize it and return.
return orjson.loads(packages) return orjson.loads(packages)
query = db.query(Package).join(PackageBase).filter( query = db.query(models.Package).join(models.PackageBase).filter(
PackageBase.PackagerUID.isnot(None) models.PackageBase.PackagerUID.isnot(None)
).order_by( ).order_by(
PackageBase.ModifiedTS.desc() models.PackageBase.ModifiedTS.desc()
) )
if limit: if limit:
@ -178,7 +174,8 @@ def updated_packages(limit: int = 0, cache_ttl: int = 600) -> List[Package]:
return packages return packages
def query_voted(query: List[Package], user: User) -> Dict[int, bool]: def query_voted(query: List[models.Package],
user: models.User) -> Dict[int, bool]:
""" Produce a dictionary of package base ID keys to boolean values, """ Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a vote record which indicate whether or not the package base has a vote record
related to user. related to user.
@ -189,18 +186,19 @@ def query_voted(query: List[Package], user: User) -> Dict[int, bool]:
""" """
output = defaultdict(bool) output = defaultdict(bool)
query_set = {pkg.PackageBase.ID for pkg in query} query_set = {pkg.PackageBase.ID for pkg in query}
voted = db.query(PackageVote).join( voted = db.query(models.PackageVote).join(
PackageBase, models.PackageBase,
PackageBase.ID.in_(query_set) models.PackageBase.ID.in_(query_set)
).filter( ).filter(
PackageVote.UsersID == user.ID models.PackageVote.UsersID == user.ID
) )
for vote in voted: for vote in voted:
output[vote.PackageBase.ID] = True output[vote.PackageBase.ID] = True
return output return output
def query_notified(query: List[Package], user: User) -> Dict[int, bool]: def query_notified(query: List[models.Package],
user: models.User) -> Dict[int, bool]:
""" Produce a dictionary of package base ID keys to boolean values, """ Produce a dictionary of package base ID keys to boolean values,
which indicate whether or not the package base has a notification which indicate whether or not the package base has a notification
record related to user. record related to user.
@ -211,11 +209,11 @@ def query_notified(query: List[Package], user: User) -> Dict[int, bool]:
""" """
output = defaultdict(bool) output = defaultdict(bool)
query_set = {pkg.PackageBase.ID for pkg in query} query_set = {pkg.PackageBase.ID for pkg in query}
notified = db.query(PackageNotification).join( notified = db.query(models.PackageNotification).join(
PackageBase, models.PackageBase,
PackageBase.ID.in_(query_set) models.PackageBase.ID.in_(query_set)
).filter( ).filter(
PackageNotification.UserID == user.ID models.PackageNotification.UserID == user.ID
) )
for notify in notified: for notify in notified:
output[notify.PackageBase.ID] = True output[notify.PackageBase.ID] = True

View file

@ -11,17 +11,12 @@ from sqlalchemy import and_, func, or_
import aurweb.config import aurweb.config
from aurweb import db, l10n, time, util from aurweb import db, l10n, models, time, util
from aurweb.auth import account_type_required, auth_required from aurweb.auth import account_type_required, auth_required
from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token from aurweb.captcha import get_captcha_answer, get_captcha_salts, get_captcha_token
from aurweb.l10n import get_translator_for_request from aurweb.l10n import get_translator_for_request
from aurweb.models.accepted_term import AcceptedTerm from aurweb.models import account_type
from aurweb.models.account_type import (DEVELOPER, DEVELOPER_ID, TRUSTED_USER, TRUSTED_USER_AND_DEV, TRUSTED_USER_AND_DEV_ID, from aurweb.models.ssh_pub_key import get_fingerprint
TRUSTED_USER_ID, USER_ID, AccountType)
from aurweb.models.ban import Ban
from aurweb.models.ssh_pub_key import SSHPubKey, get_fingerprint
from aurweb.models.term import Term
from aurweb.models.user import User
from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification from aurweb.scripts.notify import ResetKeyNotification, WelcomeNotification
from aurweb.templates import make_context, make_variable_context, render_template from aurweb.templates import make_context, make_variable_context, render_template
@ -46,8 +41,8 @@ async def passreset_post(request: Request,
context = await make_variable_context(request, "Password Reset") context = await make_variable_context(request, "Password Reset")
# The user parameter being required, we can match against # The user parameter being required, we can match against
user = db.query(User, or_(User.Username == user, user = db.query(models.User, or_(models.User.Username == user,
User.Email == user)).first() models.User.Email == user)).first()
if not user: if not user:
context["errors"] = ["Invalid e-mail."] context["errors"] = ["Invalid e-mail."]
return render_template(request, "passreset.html", context, return render_template(request, "passreset.html", context,
@ -72,13 +67,13 @@ async def passreset_post(request: Request,
return render_template(request, "passreset.html", context, return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST) status_code=HTTPStatus.BAD_REQUEST)
if len(password) < User.minimum_passwd_length(): if len(password) < models.User.minimum_passwd_length():
# Translate the error here, which simplifies error output # Translate the error here, which simplifies error output
# in the jinja2 template. # in the jinja2 template.
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
context["errors"] = [_( context["errors"] = [_(
"Your password must be at least %s characters.") % ( "Your password must be at least %s characters.") % (
str(User.minimum_passwd_length()))] str(models.User.minimum_passwd_length()))]
return render_template(request, "passreset.html", context, return render_template(request, "passreset.html", context,
status_code=HTTPStatus.BAD_REQUEST) status_code=HTTPStatus.BAD_REQUEST)
@ -95,7 +90,7 @@ async def passreset_post(request: Request,
status_code=HTTPStatus.SEE_OTHER) status_code=HTTPStatus.SEE_OTHER)
# If we got here, we continue with issuing a resetkey for the user. # If we got here, we continue with issuing a resetkey for the user.
resetkey = db.make_random_value(User, User.ResetKey) resetkey = db.make_random_value(models.User, models.User.ResetKey)
with db.begin(): with db.begin():
user.ResetKey = resetkey user.ResetKey = resetkey
@ -107,7 +102,7 @@ async def passreset_post(request: Request,
status_code=HTTPStatus.SEE_OTHER) status_code=HTTPStatus.SEE_OTHER)
def process_account_form(request: Request, user: User, args: dict): def process_account_form(request: Request, user: models.User, args: dict):
""" Process an account form. All fields are optional and only checks """ Process an account form. All fields are optional and only checks
requirements in the case they are present. requirements in the case they are present.
@ -129,11 +124,11 @@ def process_account_form(request: Request, user: User, args: dict):
_ = get_translator_for_request(request) _ = get_translator_for_request(request)
host = request.client.host host = request.client.host
ban = db.query(Ban, Ban.IPAddress == host).first() ban = db.query(models.Ban, models.Ban.IPAddress == host).first()
if ban: if ban:
return False, [ return False, [
"Account registration has been disabled for your " + "Account registration has been disabled for your "
"IP address, probably due to sustained spam attacks. " + "IP address, probably due to sustained spam attacks. "
"Sorry for the inconvenience." "Sorry for the inconvenience."
] ]
@ -181,12 +176,12 @@ def process_account_form(request: Request, user: User, args: dict):
timezone = args.get("TZ", None) timezone = args.get("TZ", None)
def username_exists(username): def username_exists(username):
return and_(User.ID != user.ID, return and_(models.User.ID != user.ID,
func.lower(User.Username) == username.lower()) func.lower(models.User.Username) == username.lower())
def email_exists(email): def email_exists(email):
return and_(User.ID != user.ID, return and_(models.User.ID != user.ID,
func.lower(User.Email) == email.lower()) func.lower(models.User.Email) == email.lower())
if not util.valid_email(email): if not util.valid_email(email):
return False, ["The email address is invalid."] return False, ["The email address is invalid."]
@ -203,13 +198,13 @@ def process_account_form(request: Request, user: User, args: dict):
return False, ["Language is not currently supported."] return False, ["Language is not currently supported."]
elif timezone and timezone not in time.SUPPORTED_TIMEZONES: elif timezone and timezone not in time.SUPPORTED_TIMEZONES:
return False, ["Timezone is not currently supported."] return False, ["Timezone is not currently supported."]
elif db.query(User, username_exists(username)).first(): elif db.query(models.User, username_exists(username)).first():
# If the username already exists... # If the username already exists...
return False, [ return False, [
_("The username, %s%s%s, is already in use.") % ( _("The username, %s%s%s, is already in use.") % (
"<strong>", username, "</strong>") "<strong>", username, "</strong>")
] ]
elif db.query(User, email_exists(email)).first(): elif db.query(models.User, email_exists(email)).first():
# If the email already exists... # If the email already exists...
return False, [ return False, [
_("The address, %s%s%s, is already in use.") % ( _("The address, %s%s%s, is already in use.") % (
@ -217,15 +212,16 @@ def process_account_form(request: Request, user: User, args: dict):
] ]
def ssh_fingerprint_exists(fingerprint): def ssh_fingerprint_exists(fingerprint):
return and_(SSHPubKey.UserID != user.ID, return and_(models.SSHPubKey.UserID != user.ID,
SSHPubKey.Fingerprint == fingerprint) models.SSHPubKey.Fingerprint == fingerprint)
if ssh_pubkey: if ssh_pubkey:
fingerprint = get_fingerprint(ssh_pubkey.strip().rstrip()) fingerprint = get_fingerprint(ssh_pubkey.strip().rstrip())
if fingerprint is None: if fingerprint is None:
return False, ["The SSH public key is invalid."] return False, ["The SSH public key is invalid."]
if db.query(SSHPubKey, ssh_fingerprint_exists(fingerprint)).first(): if db.query(models.SSHPubKey,
ssh_fingerprint_exists(fingerprint)).first():
return False, [ return False, [
_("The SSH public key, %s%s%s, is already in use.") % ( _("The SSH public key, %s%s%s, is already in use.") % (
"<strong>", fingerprint, "</strong>") "<strong>", fingerprint, "</strong>")
@ -246,7 +242,7 @@ def process_account_form(request: Request, user: User, args: dict):
def make_account_form_context(context: dict, def make_account_form_context(context: dict,
request: Request, request: Request,
user: User, user: models.User,
args: dict): args: dict):
""" Modify a FastAPI context and add attributes for the account form. """ Modify a FastAPI context and add attributes for the account form.
@ -382,20 +378,20 @@ async def account_register_post(request: Request,
# Create a user with no password with a resetkey, then send # Create a user with no password with a resetkey, then send
# an email off about it. # an email off about it.
resetkey = db.make_random_value(User, User.ResetKey) resetkey = db.make_random_value(models.User, models.User.ResetKey)
# By default, we grab the User account type to associate with. # By default, we grab the User account type to associate with.
account_type = db.query(AccountType, atype = db.query(models.AccountType,
AccountType.AccountType == "User").first() models.AccountType.AccountType == "User").first()
# Create a user given all parameters available. # Create a user given all parameters available.
with db.begin(): with db.begin():
user = db.create(User, Username=U, user = db.create(models.User, Username=U,
Email=E, HideEmail=H, BackupEmail=BE, Email=E, HideEmail=H, BackupEmail=BE,
RealName=R, Homepage=HP, IRCNick=I, PGPKey=K, RealName=R, Homepage=HP, IRCNick=I, PGPKey=K,
LangPreference=L, Timezone=TZ, CommentNotify=CN, LangPreference=L, Timezone=TZ, CommentNotify=CN,
UpdateNotify=UN, OwnershipNotify=ON, UpdateNotify=UN, OwnershipNotify=ON,
ResetKey=resetkey, AccountType=account_type) ResetKey=resetkey, AccountType=atype)
# If a PK was given and either one does not exist or the given # If a PK was given and either one does not exist or the given
# PK mismatches the existing user's SSHPubKey.PubKey. # PK mismatches the existing user's SSHPubKey.PubKey.
@ -408,9 +404,9 @@ async def account_register_post(request: Request,
pubkey = parts[0] + " " + parts[1] pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey) fingerprint = get_fingerprint(pubkey)
with db.begin(): with db.begin():
user.ssh_pub_key = SSHPubKey(UserID=user.ID, user.ssh_pub_key = models.SSHPubKey(UserID=user.ID,
PubKey=pubkey, PubKey=pubkey,
Fingerprint=fingerprint) Fingerprint=fingerprint)
# Send a reset key notification to the new user. # Send a reset key notification to the new user.
executor = db.ConnectionExecutor(db.get_engine().raw_connection()) executor = db.ConnectionExecutor(db.get_engine().raw_connection())
@ -435,7 +431,7 @@ def cannot_edit(request, user):
@auth_required(True, redirect="/account/{username}") @auth_required(True, redirect="/account/{username}")
async def account_edit(request: Request, async def account_edit(request: Request,
username: str): username: str):
user = db.query(User, User.Username == username).first() user = db.query(models.User, models.User.Username == username).first()
response = cannot_edit(request, user) response = cannot_edit(request, user)
if response: if response:
return response return response
@ -473,7 +469,8 @@ async def account_edit_post(request: Request,
passwd: str = Form(default=str())): passwd: str = Form(default=str())):
from aurweb.db import session from aurweb.db import session
user = session.query(User).filter(User.Username == username).first() user = session.query(models.User).filter(
models.User.Username == username).first()
response = cannot_edit(request, user) response = cannot_edit(request, user)
if response: if response:
return response return response
@ -538,9 +535,9 @@ async def account_edit_post(request: Request,
fingerprint = get_fingerprint(pubkey) fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key: if not user.ssh_pub_key:
# No public key exists, create one. # No public key exists, create one.
user.ssh_pub_key = SSHPubKey(UserID=user.ID, user.ssh_pub_key = models.SSHPubKey(UserID=user.ID,
PubKey=pubkey, PubKey=pubkey,
Fingerprint=fingerprint) Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey: elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it. # A public key already exists, update it.
user.ssh_pub_key.PubKey = pubkey user.ssh_pub_key.PubKey = pubkey
@ -584,7 +581,7 @@ async def account(request: Request, username: str):
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
context = await make_variable_context(request, _("Account") + username) context = await make_variable_context(request, _("Account") + username)
user = db.query(User, User.Username == username).first() user = db.query(models.User, models.User.Username == username).first()
if not user: if not user:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
@ -595,7 +592,9 @@ async def account(request: Request, username: str):
@router.get("/accounts/") @router.get("/accounts/")
@auth_required(True, redirect="/accounts/") @auth_required(True, redirect="/accounts/")
@account_type_required({TRUSTED_USER, DEVELOPER, TRUSTED_USER_AND_DEV}) @account_type_required({account_type.TRUSTED_USER,
account_type.DEVELOPER,
account_type.TRUSTED_USER_AND_DEV})
async def accounts(request: Request): async def accounts(request: Request):
context = make_context(request, "Accounts") context = make_context(request, "Accounts")
return render_template(request, "account/search.html", context) return render_template(request, "account/search.html", context)
@ -603,7 +602,9 @@ async def accounts(request: Request):
@router.post("/accounts/") @router.post("/accounts/")
@auth_required(True, redirect="/accounts/") @auth_required(True, redirect="/accounts/")
@account_type_required({TRUSTED_USER, DEVELOPER, TRUSTED_USER_AND_DEV}) @account_type_required({account_type.TRUSTED_USER,
account_type.DEVELOPER,
account_type.TRUSTED_USER_AND_DEV})
async def accounts_post(request: Request, async def accounts_post(request: Request,
O: int = Form(default=0), # Offset O: int = Form(default=0), # Offset
SB: str = Form(default=str()), # Search By SB: str = Form(default=str()), # Search By
@ -626,44 +627,44 @@ async def accounts_post(request: Request,
# Setup order by criteria based on SB. # Setup order by criteria based on SB.
order_by_columns = { order_by_columns = {
"t": (AccountType.ID.asc(), User.Username.asc()), "t": (models.AccountType.ID.asc(), models.User.Username.asc()),
"r": (User.RealName.asc(), AccountType.ID.asc()), "r": (models.User.RealName.asc(), models.AccountType.ID.asc()),
"i": (User.IRCNick.asc(), AccountType.ID.asc()), "i": (models.User.IRCNick.asc(), models.AccountType.ID.asc()),
} }
default_order = (User.Username.asc(), AccountType.ID.asc()) default_order = (models.User.Username.asc(), models.AccountType.ID.asc())
order_by = order_by_columns.get(SB, default_order) order_by = order_by_columns.get(SB, default_order)
# Convert parameter T to an AccountType ID. # Convert parameter T to an AccountType ID.
account_types = { account_types = {
"u": USER_ID, "u": account_type.USER_ID,
"t": TRUSTED_USER_ID, "t": account_type.TRUSTED_USER_ID,
"d": DEVELOPER_ID, "d": account_type.DEVELOPER_ID,
"td": TRUSTED_USER_AND_DEV_ID "td": account_type.TRUSTED_USER_AND_DEV_ID
} }
account_type_id = account_types.get(T, None) account_type_id = account_types.get(T, None)
# Get a query handle to users, populate the total user # Get a query handle to users, populate the total user
# count into a jinja2 context variable. # count into a jinja2 context variable.
query = db.query(User).join(AccountType) query = db.query(models.User).join(models.AccountType)
context["total_users"] = query.count() context["total_users"] = query.count()
# Populate this list with any additional statements to # Populate this list with any additional statements to
# be ANDed together. # be ANDed together.
statements = [] statements = []
if account_type_id is not None: if account_type_id is not None:
statements.append(AccountType.ID == account_type_id) statements.append(models.AccountType.ID == account_type_id)
if U: if U:
statements.append(User.Username.like(f"%{U}%")) statements.append(models.User.Username.like(f"%{U}%"))
if S: if S:
statements.append(User.Suspended == S) statements.append(models.User.Suspended == S)
if E: if E:
statements.append(User.Email.like(f"%{E}%")) statements.append(models.User.Email.like(f"%{E}%"))
if R: if R:
statements.append(User.RealName.like(f"%{R}%")) statements.append(models.User.RealName.like(f"%{R}%"))
if I: if I:
statements.append(User.IRCNick.like(f"%{I}%")) statements.append(models.User.IRCNick.like(f"%{I}%"))
if K: if K:
statements.append(User.PGPKey.like(f"%{K}%")) statements.append(models.User.PGPKey.like(f"%{K}%"))
# Filter the query by combining all statements added above into # Filter the query by combining all statements added above into
# an AND statement, unless there's just one statement, which # an AND statement, unless there's just one statement, which
@ -692,12 +693,12 @@ def render_terms_of_service(request: Request,
async def terms_of_service(request: Request): async def terms_of_service(request: Request):
# Query the database for terms that were previously accepted, # Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted. # but now have a bumped Revision that needs to be accepted.
diffs = db.query(Term).join(AcceptedTerm).filter( diffs = db.query(models.Term).join(models.AcceptedTerm).filter(
AcceptedTerm.Revision < Term.Revision).all() models.AcceptedTerm.Revision < models.Term.Revision).all()
# Query the database for any terms that have not yet been accepted. # Query the database for any terms that have not yet been accepted.
unaccepted = db.query(Term).filter( unaccepted = db.query(models.Term).filter(
~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all() ~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
# Translate the 'Terms of Service' part of our page title. # Translate the 'Terms of Service' part of our page title.
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
@ -714,12 +715,12 @@ async def terms_of_service_post(request: Request,
accept: bool = Form(default=False)): accept: bool = Form(default=False)):
# Query the database for terms that were previously accepted, # Query the database for terms that were previously accepted,
# but now have a bumped Revision that needs to be accepted. # but now have a bumped Revision that needs to be accepted.
diffs = db.query(Term).join(AcceptedTerm).filter( diffs = db.query(models.Term).join(models.AcceptedTerm).filter(
AcceptedTerm.Revision < Term.Revision).all() models.AcceptedTerm.Revision < models.Term.Revision).all()
# Query the database for any terms that have not yet been accepted. # Query the database for any terms that have not yet been accepted.
unaccepted = db.query(Term).filter( unaccepted = db.query(models.Term).filter(
~Term.ID.in_(db.query(AcceptedTerm.TermsID))).all() ~models.Term.ID.in_(db.query(models.AcceptedTerm.TermsID))).all()
if not accept: if not accept:
# Translate the 'Terms of Service' part of our page title. # Translate the 'Terms of Service' part of our page title.
@ -737,12 +738,12 @@ async def terms_of_service_post(request: Request,
# and update its Revision to the term's current Revision. # and update its Revision to the term's current Revision.
for term in diffs: for term in diffs:
accepted_term = request.user.accepted_terms.filter( accepted_term = request.user.accepted_terms.filter(
AcceptedTerm.TermsID == term.ID).first() models.AcceptedTerm.TermsID == term.ID).first()
accepted_term.Revision = term.Revision accepted_term.Revision = term.Revision
# For each term that was never accepted, accept it! # For each term that was never accepted, accept it!
for term in unaccepted: for term in unaccepted:
db.create(AcceptedTerm, User=request.user, db.create(models.AcceptedTerm, User=request.user,
Term=term, Revision=term.Revision) Term=term, Revision=term.Revision)
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)

View file

@ -8,7 +8,7 @@ import aurweb.config
from aurweb import util from aurweb import util
from aurweb.auth import auth_required from aurweb.auth import auth_required
from aurweb.models.user import User from aurweb.models import User
from aurweb.templates import make_variable_context, render_template from aurweb.templates import make_variable_context, render_template
router = APIRouter() router = APIRouter()

View file

@ -11,14 +11,10 @@ from sqlalchemy import and_, case, or_
import aurweb.config import aurweb.config
import aurweb.models.package_request import aurweb.models.package_request
from aurweb import db, util from aurweb import db, models, util
from aurweb.cache import db_count_cache from aurweb.cache import db_count_cache
from aurweb.models.account_type import TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID from aurweb.models.account_type import TRUSTED_USER_AND_DEV_ID, TRUSTED_USER_ID
from aurweb.models.package import Package from aurweb.models.package_request import PENDING_ID
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_request import PENDING_ID, PackageRequest
from aurweb.models.user import User
from aurweb.packages.util import query_notified, query_voted, updated_packages from aurweb.packages.util import query_notified, query_voted, updated_packages
from aurweb.templates import make_context, render_template from aurweb.templates import make_context, render_template
@ -69,31 +65,31 @@ async def index(request: Request):
context = make_context(request, "Home") context = make_context(request, "Home")
context['ssh_fingerprints'] = util.get_ssh_fingerprints() context['ssh_fingerprints'] = util.get_ssh_fingerprints()
bases = db.query(PackageBase) bases = db.query(models.PackageBase)
redis = aurweb.redis.redis_connection() redis = aurweb.redis.redis_connection()
stats_expire = 300 # Five minutes. stats_expire = 300 # Five minutes.
updates_expire = 600 # Ten minutes. updates_expire = 600 # Ten minutes.
# Package statistics. # Package statistics.
query = bases.filter(PackageBase.PackagerUID.isnot(None)) query = bases.filter(models.PackageBase.PackagerUID.isnot(None))
context["package_count"] = await db_count_cache( context["package_count"] = await db_count_cache(
redis, "package_count", query, expire=stats_expire) redis, "package_count", query, expire=stats_expire)
query = bases.filter( query = bases.filter(
and_(PackageBase.MaintainerUID.is_(None), and_(models.PackageBase.MaintainerUID.is_(None),
PackageBase.PackagerUID.isnot(None)) models.PackageBase.PackagerUID.isnot(None))
) )
context["orphan_count"] = await db_count_cache( context["orphan_count"] = await db_count_cache(
redis, "orphan_count", query, expire=stats_expire) redis, "orphan_count", query, expire=stats_expire)
query = db.query(User) query = db.query(models.User)
context["user_count"] = await db_count_cache( context["user_count"] = await db_count_cache(
redis, "user_count", query, expire=stats_expire) redis, "user_count", query, expire=stats_expire)
query = query.filter( query = query.filter(
or_(User.AccountTypeID == TRUSTED_USER_ID, or_(models.User.AccountTypeID == TRUSTED_USER_ID,
User.AccountTypeID == TRUSTED_USER_AND_DEV_ID)) models.User.AccountTypeID == TRUSTED_USER_AND_DEV_ID))
context["trusted_user_count"] = await db_count_cache( context["trusted_user_count"] = await db_count_cache(
redis, "trusted_user_count", query, expire=stats_expire) redis, "trusted_user_count", query, expire=stats_expire)
@ -105,29 +101,29 @@ async def index(request: Request):
one_hour = 3600 one_hour = 3600
updated = bases.filter( updated = bases.filter(
and_(PackageBase.ModifiedTS - PackageBase.SubmittedTS >= one_hour, and_(models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS >= one_hour,
PackageBase.PackagerUID.isnot(None)) models.PackageBase.PackagerUID.isnot(None))
) )
query = bases.filter( query = bases.filter(
and_(PackageBase.SubmittedTS >= seven_days_ago, and_(models.PackageBase.SubmittedTS >= seven_days_ago,
PackageBase.PackagerUID.isnot(None)) models.PackageBase.PackagerUID.isnot(None))
) )
context["seven_days_old_added"] = await db_count_cache( context["seven_days_old_added"] = await db_count_cache(
redis, "seven_days_old_added", query, expire=stats_expire) redis, "seven_days_old_added", query, expire=stats_expire)
query = updated.filter(PackageBase.ModifiedTS >= seven_days_ago) query = updated.filter(models.PackageBase.ModifiedTS >= seven_days_ago)
context["seven_days_old_updated"] = await db_count_cache( context["seven_days_old_updated"] = await db_count_cache(
redis, "seven_days_old_updated", query, expire=stats_expire) redis, "seven_days_old_updated", query, expire=stats_expire)
year = seven_days * 52 # Fifty two weeks worth: one year. year = seven_days * 52 # Fifty two weeks worth: one year.
year_ago = now - year year_ago = now - year
query = updated.filter(PackageBase.ModifiedTS >= year_ago) query = updated.filter(models.PackageBase.ModifiedTS >= year_ago)
context["year_old_updated"] = await db_count_cache( context["year_old_updated"] = await db_count_cache(
redis, "year_old_updated", query, expire=stats_expire) redis, "year_old_updated", query, expire=stats_expire)
query = bases.filter( query = bases.filter(
PackageBase.ModifiedTS - PackageBase.SubmittedTS < 3600) models.PackageBase.ModifiedTS - models.PackageBase.SubmittedTS < 3600)
context["never_updated"] = await db_count_cache( context["never_updated"] = await db_count_cache(
redis, "never_updated", query, expire=stats_expire) redis, "never_updated", query, expire=stats_expire)
@ -137,19 +133,19 @@ async def index(request: Request):
if request.user.is_authenticated(): if request.user.is_authenticated():
# Authenticated users get a few extra pieces of data for # Authenticated users get a few extra pieces of data for
# the dashboard display. # the dashboard display.
packages = db.query(Package).join(PackageBase) packages = db.query(models.Package).join(models.PackageBase)
maintained = packages.join( maintained = packages.join(
User, PackageBase.MaintainerUID == User.ID models.User, models.PackageBase.MaintainerUID == models.User.ID
).filter( ).filter(
PackageBase.MaintainerUID == request.user.ID models.PackageBase.MaintainerUID == request.user.ID
) )
# Packages maintained by the user that have been flagged. # Packages maintained by the user that have been flagged.
context["flagged_packages"] = maintained.filter( context["flagged_packages"] = maintained.filter(
PackageBase.OutOfDateTS.isnot(None) models.PackageBase.OutOfDateTS.isnot(None)
).order_by( ).order_by(
PackageBase.ModifiedTS.desc(), Package.Name.asc() models.PackageBase.ModifiedTS.desc(), models.Package.Name.asc()
).limit(50).all() ).limit(50).all()
# Flagged packages that request.user has voted for. # Flagged packages that request.user has voted for.
@ -165,17 +161,18 @@ async def index(request: Request):
# Package requests created by request.user. # Package requests created by request.user.
context["package_requests"] = request.user.package_requests.filter( context["package_requests"] = request.user.package_requests.filter(
PackageRequest.RequestTS >= start models.PackageRequest.RequestTS >= start
).order_by( ).order_by(
# Order primarily by the Status column being PENDING_ID, # Order primarily by the Status column being PENDING_ID,
# and secondarily by RequestTS; both in descending order. # and secondarily by RequestTS; both in descending order.
case([(PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(), case([(models.PackageRequest.Status == PENDING_ID, 1)],
PackageRequest.RequestTS.desc() else_=0).desc(),
models.PackageRequest.RequestTS.desc()
).limit(50).all() ).limit(50).all()
# Packages that the request user maintains or comaintains. # Packages that the request user maintains or comaintains.
context["packages"] = maintained.order_by( context["packages"] = maintained.order_by(
PackageBase.ModifiedTS.desc(), Package.Name.desc() models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc()
).limit(50).all() ).limit(50).all()
# Packages that request.user has voted for. # Packages that request.user has voted for.
@ -188,11 +185,11 @@ async def index(request: Request):
# Any packages that the request user comaintains. # Any packages that the request user comaintains.
context["comaintained"] = packages.join( context["comaintained"] = packages.join(
PackageComaintainer models.PackageComaintainer
).filter( ).filter(
PackageComaintainer.UsersID == request.user.ID models.PackageComaintainer.UsersID == request.user.ID
).order_by( ).order_by(
PackageBase.ModifiedTS.desc(), Package.Name.desc() models.PackageBase.ModifiedTS.desc(), models.Package.Name.desc()
).limit(50).all() ).limit(50).all()
# Comaintained packages that request.user has voted for. # Comaintained packages that request.user has voted for.

View file

@ -7,27 +7,13 @@ from fastapi.responses import JSONResponse, RedirectResponse
from sqlalchemy import and_, case from sqlalchemy import and_, case
import aurweb.filters import aurweb.filters
import aurweb.models.package_comment
import aurweb.models.package_keyword
import aurweb.packages.util import aurweb.packages.util
from aurweb import db, defaults, l10n, util from aurweb import db, defaults, l10n, models, util
from aurweb.auth import auth_required from aurweb.auth import auth_required
from aurweb.models.license import License from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID
from aurweb.models.package import Package
from aurweb.models.package_base import PackageBase
from aurweb.models.package_comaintainer import PackageComaintainer
from aurweb.models.package_comment import PackageComment
from aurweb.models.package_dependency import PackageDependency
from aurweb.models.package_license import PackageLicense
from aurweb.models.package_notification import PackageNotification
from aurweb.models.package_relation import PackageRelation
from aurweb.models.package_request import ACCEPTED_ID, PENDING_ID, REJECTED_ID, PackageRequest
from aurweb.models.package_source import PackageSource
from aurweb.models.package_vote import PackageVote
from aurweb.models.relation_type import CONFLICTS_ID from aurweb.models.relation_type import CONFLICTS_ID
from aurweb.models.request_type import DELETION_ID, RequestType from aurweb.models.request_type import DELETION_ID
from aurweb.models.user import User
from aurweb.packages.search import PackageSearch from aurweb.packages.search import PackageSearch
from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted from aurweb.packages.util import get_pkg_or_base, get_pkgbase_comment, query_notified, query_voted
from aurweb.scripts import notify, popupdate from aurweb.scripts import notify, popupdate
@ -75,9 +61,9 @@ async def packages_get(request: Request, context: Dict[str, Any]):
# do **not** have OutOfDateTS. # do **not** have OutOfDateTS.
criteria = None criteria = None
if flagged == "on": if flagged == "on":
criteria = PackageBase.OutOfDateTS.isnot criteria = models.PackageBase.OutOfDateTS.isnot
else: else:
criteria = PackageBase.OutOfDateTS.is_ criteria = models.PackageBase.OutOfDateTS.is_
# Apply the flag criteria to our PackageSearch.query. # Apply the flag criteria to our PackageSearch.query.
search.query = search.query.filter(criteria(None)) search.query = search.query.filter(criteria(None))
@ -86,7 +72,8 @@ async def packages_get(request: Request, context: Dict[str, Any]):
if submit == "Orphans": if submit == "Orphans":
# If the user clicked the "Orphans" button, we only want # If the user clicked the "Orphans" button, we only want
# orphaned packages. # orphaned packages.
search.query = search.query.filter(PackageBase.MaintainerUID.is_(None)) search.query = search.query.filter(
models.PackageBase.MaintainerUID.is_(None))
# Apply user-specified specified sort column and ordering. # Apply user-specified specified sort column and ordering.
search.sort_by(sort_by, sort_order) search.sort_by(sort_by, sort_order)
@ -116,19 +103,19 @@ async def packages(request: Request) -> Response:
return await packages_get(request, context) return await packages_get(request, context)
def create_request_if_missing(requests: List[PackageRequest], def create_request_if_missing(requests: List[models.PackageRequest],
reqtype: RequestType, reqtype: models.RequestType,
user: User, user: models.User,
package: Package): package: models.Package):
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
pkgreq = db.query(PackageRequest).filter( pkgreq = db.query(models.PackageRequest).filter(
PackageRequest.PackageBaseName == package.PackageBase.Name models.PackageRequest.PackageBaseName == package.PackageBase.Name
).first() ).first()
if not pkgreq: if not pkgreq:
# No PackageRequest existed. Create one. # No PackageRequest existed. Create one.
comments = "Automatically generated by aurweb." comments = "Automatically generated by aurweb."
closure_comment = "Deleted by aurweb." closure_comment = "Deleted by aurweb."
pkgreq = db.create(PackageRequest, pkgreq = db.create(models.PackageRequest,
RequestType=reqtype, RequestType=reqtype,
PackageBase=package.PackageBase, PackageBase=package.PackageBase,
PackageBaseName=package.PackageBase.Name, PackageBaseName=package.PackageBase.Name,
@ -141,8 +128,7 @@ def create_request_if_missing(requests: List[PackageRequest],
requests.append(pkgreq) requests.append(pkgreq)
def delete_package(deleter: User, def delete_package(deleter: models.User, package: models.Package):
package: Package):
notifications = [] notifications = []
requests = [] requests = []
bases_to_delete = [] bases_to_delete = []
@ -150,8 +136,8 @@ def delete_package(deleter: User,
conn = db.ConnectionExecutor(db.get_engine().raw_connection()) conn = db.ConnectionExecutor(db.get_engine().raw_connection())
# In all cases, though, just delete the Package in question. # In all cases, though, just delete the Package in question.
if package.PackageBase.packages.count() == 1: if package.PackageBase.packages.count() == 1:
reqtype = db.query(RequestType).filter( reqtype = db.query(models.RequestType).filter(
RequestType.ID == DELETION_ID models.RequestType.ID == DELETION_ID
).first() ).first()
with db.begin(): with db.begin():
@ -187,7 +173,7 @@ def delete_package(deleter: User,
async def make_single_context(request: Request, async def make_single_context(request: Request,
pkgbase: PackageBase) -> Dict[str, Any]: pkgbase: models.PackageBase) -> Dict[str, Any]:
""" Make a basic context for package or pkgbase. """ Make a basic context for package or pkgbase.
:param request: FastAPI request :param request: FastAPI request
@ -203,11 +189,11 @@ async def make_single_context(request: Request,
context["packages_count"] = pkgbase.packages.count() context["packages_count"] = pkgbase.packages.count()
context["keywords"] = pkgbase.keywords context["keywords"] = pkgbase.keywords
context["comments"] = pkgbase.comments.order_by( context["comments"] = pkgbase.comments.order_by(
PackageComment.CommentTS.desc() models.PackageComment.CommentTS.desc()
) )
context["pinned_comments"] = pkgbase.comments.filter( context["pinned_comments"] = pkgbase.comments.filter(
PackageComment.PinnedTS != 0 models.PackageComment.PinnedTS != 0
).order_by(PackageComment.CommentTS.desc()) ).order_by(models.PackageComment.CommentTS.desc())
context["is_maintainer"] = (request.user.is_authenticated() context["is_maintainer"] = (request.user.is_authenticated()
and request.user.ID == pkgbase.MaintainerUID) and request.user.ID == pkgbase.MaintainerUID)
@ -216,10 +202,10 @@ async def make_single_context(request: Request,
context["out_of_date"] = bool(pkgbase.OutOfDateTS) context["out_of_date"] = bool(pkgbase.OutOfDateTS)
context["voted"] = request.user.package_votes.filter( context["voted"] = request.user.package_votes.filter(
PackageVote.PackageBaseID == pkgbase.ID).scalar() models.PackageVote.PackageBaseID == pkgbase.ID).scalar()
context["requests"] = pkgbase.requests.filter( context["requests"] = pkgbase.requests.filter(
PackageRequest.ClosedTS.is_(None) models.PackageRequest.ClosedTS.is_(None)
).count() ).count()
return context return context
@ -228,8 +214,8 @@ async def make_single_context(request: Request,
@router.get("/packages/{name}") @router.get("/packages/{name}")
async def package(request: Request, name: str) -> Response: async def package(request: Request, name: str) -> Response:
# Get the Package. # Get the Package.
pkg = get_pkg_or_base(name, Package) pkg = get_pkg_or_base(name, models.Package)
pkgbase = (get_pkg_or_base(name, PackageBase) pkgbase = (get_pkg_or_base(name, models.PackageBase)
if not pkg else pkg.PackageBase) if not pkg else pkg.PackageBase)
# Add our base information. # Add our base information.
@ -237,28 +223,32 @@ async def package(request: Request, name: str) -> Response:
context["package"] = pkg context["package"] = pkg
# Package sources. # Package sources.
context["sources"] = db.query(PackageSource).join(Package).join( context["sources"] = db.query(models.PackageSource).join(
PackageBase).filter(PackageBase.ID == pkgbase.ID) models.Package).join(models.PackageBase).filter(
models.PackageBase.ID == pkgbase.ID)
# Package dependencies. # Package dependencies.
dependencies = db.query(PackageDependency).join(Package).join( dependencies = db.query(models.PackageDependency).join(
PackageBase).filter(PackageBase.ID == pkgbase.ID) models.Package).join(models.PackageBase).filter(
models.PackageBase.ID == pkgbase.ID)
context["dependencies"] = dependencies context["dependencies"] = dependencies
# Package requirements (other packages depend on this one). # Package requirements (other packages depend on this one).
required_by = db.query(PackageDependency).join(Package).filter( required_by = db.query(models.PackageDependency).join(
PackageDependency.DepName == pkgbase.Name).order_by( models.Package).filter(
Package.Name.asc()) models.PackageDependency.DepName == pkgbase.Name).order_by(
models.Package.Name.asc())
context["required_by"] = required_by context["required_by"] = required_by
licenses = db.query(License).join(PackageLicense).join(Package).join( licenses = db.query(models.License).join(models.PackageLicense).join(
PackageBase).filter(PackageBase.ID == pkgbase.ID) models.Package).join(models.PackageBase).filter(
models.PackageBase.ID == pkgbase.ID)
context["licenses"] = licenses context["licenses"] = licenses
conflicts = db.query(PackageRelation).join(Package).join( conflicts = db.query(models.PackageRelation).join(models.Package).join(
PackageBase).filter( models.PackageBase).filter(
and_(PackageRelation.RelTypeID == CONFLICTS_ID, and_(models.PackageRelation.RelTypeID == CONFLICTS_ID,
PackageBase.ID == pkgbase.ID) models.PackageBase.ID == pkgbase.ID)
) )
context["conflicts"] = conflicts context["conflicts"] = conflicts
@ -268,7 +258,7 @@ async def package(request: Request, name: str) -> Response:
@router.get("/pkgbase/{name}") @router.get("/pkgbase/{name}")
async def package_base(request: Request, name: str) -> Response: async def package_base(request: Request, name: str) -> Response:
# Get the PackageBase. # Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
# If this is not a split package, redirect to /packages/{name}. # If this is not a split package, redirect to /packages/{name}.
if pkgbase.packages.count() == 1: if pkgbase.packages.count() == 1:
@ -285,7 +275,7 @@ async def package_base(request: Request, name: str) -> Response:
@router.get("/pkgbase/{name}/voters") @router.get("/pkgbase/{name}/voters")
async def package_base_voters(request: Request, name: str) -> Response: async def package_base_voters(request: Request, name: str) -> Response:
# Get the PackageBase. # Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
context = make_context(request, "Voters") context = make_context(request, "Voters")
context["pkgbase"] = pkgbase context["pkgbase"] = pkgbase
return render_template(request, "pkgbase/voters.html", context) return render_template(request, "pkgbase/voters.html", context)
@ -298,7 +288,7 @@ async def pkgbase_comments_post(
comment: str = Form(default=str()), comment: str = Form(default=str()),
enable_notifications: bool = Form(default=False)): enable_notifications: bool = Form(default=False)):
""" Add a new comment. """ """ Add a new comment. """
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
if not comment: if not comment:
raise HTTPException(status_code=HTTPStatus.BAD_REQUEST) raise HTTPException(status_code=HTTPStatus.BAD_REQUEST)
@ -307,13 +297,13 @@ async def pkgbase_comments_post(
# update the db record. # update the db record.
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
with db.begin(): with db.begin():
comment = db.create(PackageComment, User=request.user, comment = db.create(models.PackageComment, User=request.user,
PackageBase=pkgbase, PackageBase=pkgbase,
Comments=comment, RenderedComment=str(), Comments=comment, RenderedComment=str(),
CommentTS=now) CommentTS=now)
if enable_notifications and not request.user.notified(pkgbase): if enable_notifications and not request.user.notified(pkgbase):
db.create(PackageNotification, db.create(models.PackageNotification,
User=request.user, User=request.user,
PackageBase=pkgbase) PackageBase=pkgbase)
update_comment_render(comment.ID) update_comment_render(comment.ID)
@ -327,8 +317,8 @@ async def pkgbase_comments_post(
@auth_required(True, login=False) @auth_required(True, login=False)
async def pkgbase_comment_form(request: Request, name: str, id: int): async def pkgbase_comment_form(request: Request, name: str, id: int):
""" Produce a comment form for comment {id}. """ """ Produce a comment form for comment {id}. """
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
comment = pkgbase.comments.filter(PackageComment.ID == id).first() comment = pkgbase.comments.filter(models.PackageComment.ID == id).first()
if not comment: if not comment:
return JSONResponse({}, status_code=HTTPStatus.NOT_FOUND) return JSONResponse({}, status_code=HTTPStatus.NOT_FOUND)
@ -349,7 +339,7 @@ async def pkgbase_comment_post(
request: Request, name: str, id: int, request: Request, name: str, id: int,
comment: str = Form(default=str()), comment: str = Form(default=str()),
enable_notifications: bool = Form(default=False)): enable_notifications: bool = Form(default=False)):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
db_comment = get_pkgbase_comment(pkgbase, id) db_comment = get_pkgbase_comment(pkgbase, id)
if not comment: if not comment:
@ -365,10 +355,10 @@ async def pkgbase_comment_post(
db_comment.EditedTS = now db_comment.EditedTS = now
db_notif = request.user.notifications.filter( db_notif = request.user.notifications.filter(
PackageNotification.PackageBaseID == pkgbase.ID models.PackageNotification.PackageBaseID == pkgbase.ID
).first() ).first()
if enable_notifications and not db_notif: if enable_notifications and not db_notif:
db.create(PackageNotification, db.create(models.PackageNotification,
User=request.user, User=request.user,
PackageBase=pkgbase) PackageBase=pkgbase)
update_comment_render(db_comment.ID) update_comment_render(db_comment.ID)
@ -381,7 +371,7 @@ async def pkgbase_comment_post(
@router.post("/pkgbase/{name}/comments/{id}/delete") @router.post("/pkgbase/{name}/comments/{id}/delete")
@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/delete") @auth_required(True, redirect="/pkgbase/{name}/comments/{id}/delete")
async def pkgbase_comment_delete(request: Request, name: str, id: int): async def pkgbase_comment_delete(request: Request, name: str, id: int):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
authorized = request.user.has_credential("CRED_COMMENT_DELETE", authorized = request.user.has_credential("CRED_COMMENT_DELETE",
@ -404,7 +394,7 @@ async def pkgbase_comment_delete(request: Request, name: str, id: int):
@router.post("/pkgbase/{name}/comments/{id}/undelete") @router.post("/pkgbase/{name}/comments/{id}/undelete")
@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/undelete") @auth_required(True, redirect="/pkgbase/{name}/comments/{id}/undelete")
async def pkgbase_comment_undelete(request: Request, name: str, id: int): async def pkgbase_comment_undelete(request: Request, name: str, id: int):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential("CRED_COMMENT_UNDELETE", has_cred = request.user.has_credential("CRED_COMMENT_UNDELETE",
@ -426,7 +416,7 @@ async def pkgbase_comment_undelete(request: Request, name: str, id: int):
@router.post("/pkgbase/{name}/comments/{id}/pin") @router.post("/pkgbase/{name}/comments/{id}/pin")
@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/pin") @auth_required(True, redirect="/pkgbase/{name}/comments/{id}/pin")
async def pkgbase_comment_pin(request: Request, name: str, id: int): async def pkgbase_comment_pin(request: Request, name: str, id: int):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential("CRED_COMMENT_PIN", has_cred = request.user.has_credential("CRED_COMMENT_PIN",
@ -448,7 +438,7 @@ async def pkgbase_comment_pin(request: Request, name: str, id: int):
@router.post("/pkgbase/{name}/comments/{id}/unpin") @router.post("/pkgbase/{name}/comments/{id}/unpin")
@auth_required(True, redirect="/pkgbase/{name}/comments/{id}/unpin") @auth_required(True, redirect="/pkgbase/{name}/comments/{id}/unpin")
async def pkgbase_comment_unpin(request: Request, name: str, id: int): async def pkgbase_comment_unpin(request: Request, name: str, id: int):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
comment = get_pkgbase_comment(pkgbase, id) comment = get_pkgbase_comment(pkgbase, id)
has_cred = request.user.has_credential("CRED_COMMENT_PIN", has_cred = request.user.has_credential("CRED_COMMENT_PIN",
@ -470,7 +460,7 @@ async def pkgbase_comment_unpin(request: Request, name: str, id: int):
@auth_required(True, redirect="/pkgbase/{name}/comaintainers") @auth_required(True, redirect="/pkgbase/{name}/comaintainers")
async def package_base_comaintainers(request: Request, name: str) -> Response: async def package_base_comaintainers(request: Request, name: str) -> Response:
# Get the PackageBase. # Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
# Unauthorized users (Non-TU/Dev and not the pkgbase maintainer) # Unauthorized users (Non-TU/Dev and not the pkgbase maintainer)
# get redirected to the package base's page. # get redirected to the package base's page.
@ -498,8 +488,8 @@ def remove_users(pkgbase, usernames):
for username in usernames: for username in usernames:
# We know that the users we passed here are in the DB. # We know that the users we passed here are in the DB.
# No need to check for their existence. # No need to check for their existence.
comaintainer = pkgbase.comaintainers.join(User).filter( comaintainer = pkgbase.comaintainers.join(models.User).filter(
User.Username == username models.User.Username == username
).first() ).first()
notifications.append( notifications.append(
notify.ComaintainerRemoveNotification( notify.ComaintainerRemoveNotification(
@ -519,7 +509,7 @@ async def package_base_comaintainers_post(
request: Request, name: str, request: Request, name: str,
users: str = Form(default=str())) -> Response: users: str = Form(default=str())) -> Response:
# Get the PackageBase. # Get the PackageBase.
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
# Unauthorized users (Non-TU/Dev and not the pkgbase maintainer) # Unauthorized users (Non-TU/Dev and not the pkgbase maintainer)
# get redirected to the package base's page. # get redirected to the package base's page.
@ -540,7 +530,7 @@ async def package_base_comaintainers_post(
# Get the highest priority in the comaintainer set. # Get the highest priority in the comaintainer set.
last_priority = pkgbase.comaintainers.order_by( last_priority = pkgbase.comaintainers.order_by(
PackageComaintainer.Priority.desc() models.PackageComaintainer.Priority.desc()
).limit(1).first() ).limit(1).first()
# If that record exists, we use a priority which is 1 higher. # If that record exists, we use a priority which is 1 higher.
@ -562,7 +552,8 @@ async def package_base_comaintainers_post(
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
memo = {} memo = {}
for username in usernames: for username in usernames:
user = db.query(User).filter(User.Username == username).first() user = db.query(models.User).filter(
models.User.Username == username).first()
if not user: if not user:
return _("Invalid user name: %s") % username return _("Invalid user name: %s") % username
memo[username] = user memo[username] = user
@ -579,7 +570,7 @@ async def package_base_comaintainers_post(
# If we get here, our user model object is in the memo. # If we get here, our user model object is in the memo.
comaintainer = db.create( comaintainer = db.create(
PackageComaintainer, models.PackageComaintainer,
PackageBase=pkgbase, PackageBase=pkgbase,
User=user, User=user,
Priority=priority) Priority=priority)
@ -620,21 +611,21 @@ async def requests(request: Request,
context["PP"] = PP context["PP"] = PP
# A PackageRequest query, with left inner joined User and RequestType. # A PackageRequest query, with left inner joined User and RequestType.
query = db.query(PackageRequest).join( query = db.query(models.PackageRequest).join(
User, PackageRequest.UsersID == User.ID models.User, models.PackageRequest.UsersID == models.User.ID
).join(RequestType) ).join(models.RequestType)
# If the request user is not elevated (TU or Dev), then # If the request user is not elevated (TU or Dev), then
# filter PackageRequests which are owned by the request user. # filter PackageRequests which are owned by the request user.
if not request.user.is_elevated(): if not request.user.is_elevated():
query = query.filter(PackageRequest.UsersID == request.user.ID) query = query.filter(models.PackageRequest.UsersID == request.user.ID)
context["total"] = query.count() context["total"] = query.count()
context["results"] = query.order_by( context["results"] = query.order_by(
# Order primarily by the Status column being PENDING_ID, # Order primarily by the Status column being PENDING_ID,
# and secondarily by RequestTS; both in descending order. # and secondarily by RequestTS; both in descending order.
case([(PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(), case([(models.PackageRequest.Status == PENDING_ID, 1)], else_=0).desc(),
PackageRequest.RequestTS.desc() models.PackageRequest.RequestTS.desc()
).limit(PP).offset(O).all() ).limit(PP).offset(O).all()
return render_template(request, "requests.html", context) return render_template(request, "requests.html", context)
@ -645,7 +636,8 @@ async def requests(request: Request,
async def package_request(request: Request, name: str): async def package_request(request: Request, name: str):
context = make_context(request, "Submit Request") context = make_context(request, "Submit Request")
pkgbase = db.query(PackageBase).filter(PackageBase.Name == name).first() pkgbase = db.query(models.PackageBase).filter(
models.PackageBase.Name == name).first()
if not pkgbase: if not pkgbase:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
@ -660,7 +652,7 @@ async def pkgbase_request_post(request: Request, name: str,
type: str = Form(...), type: str = Form(...),
merge_into: str = Form(default=None), merge_into: str = Form(default=None),
comments: str = Form(default=str())): comments: str = Form(default=str())):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
# Create our render context. # Create our render context.
context = make_context(request, "Submit Request") context = make_context(request, "Submit Request")
@ -682,8 +674,8 @@ async def pkgbase_request_post(request: Request, name: str,
context["errors"] = ['The "Merge into" field must not be empty.'] context["errors"] = ['The "Merge into" field must not be empty.']
return render_template(request, "pkgbase/request.html", context) return render_template(request, "pkgbase/request.html", context)
target = db.query(PackageBase).filter( target = db.query(models.PackageBase).filter(
PackageBase.Name == merge_into models.PackageBase.Name == merge_into
).first() ).first()
if not target: if not target:
# TODO: This error needs to be translated. # TODO: This error needs to be translated.
@ -701,12 +693,14 @@ async def pkgbase_request_post(request: Request, name: str,
# All good. Create a new PackageRequest based on the given type. # All good. Create a new PackageRequest based on the given type.
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
reqtype = db.query(RequestType, RequestType.Name == type).first() reqtype = db.query(models.RequestType).filter(
models.RequestType.Name == type).first()
conn = db.ConnectionExecutor(db.get_engine().raw_connection()) conn = db.ConnectionExecutor(db.get_engine().raw_connection())
notify_ = None notify_ = None
with db.begin(): with db.begin():
pkgreq = db.create(PackageRequest, RequestType=reqtype, RequestTS=now, pkgreq = db.create(models.PackageRequest, RequestType=reqtype,
PackageBase=pkgbase, PackageBaseName=pkgbase.Name, RequestTS=now, PackageBase=pkgbase,
PackageBaseName=pkgbase.Name,
MergeBaseName=merge_into, User=request.user, MergeBaseName=merge_into, User=request.user,
Comments=comments, ClosureComment=str()) Comments=comments, ClosureComment=str())
@ -726,7 +720,8 @@ async def pkgbase_request_post(request: Request, name: str,
@router.get("/requests/{id}/close") @router.get("/requests/{id}/close")
@auth_required(True, redirect="/requests/{id}/close") @auth_required(True, redirect="/requests/{id}/close")
async def requests_close(request: Request, id: int): async def requests_close(request: Request, id: int):
pkgreq = db.query(PackageRequest).filter(PackageRequest.ID == id).first() pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
if not request.user.is_elevated() and request.user != pkgreq.User: if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'. # Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -741,7 +736,8 @@ async def requests_close(request: Request, id: int):
async def requests_close_post(request: Request, id: int, async def requests_close_post(request: Request, id: int,
reason: int = Form(default=0), reason: int = Form(default=0),
comments: str = Form(default=str())): comments: str = Form(default=str())):
pkgreq = db.query(PackageRequest).filter(PackageRequest.ID == id).first() pkgreq = db.query(models.PackageRequest).filter(
models.PackageRequest.ID == id).first()
if not request.user.is_elevated() and request.user != pkgreq.User: if not request.user.is_elevated() and request.user != pkgreq.User:
# Request user doesn't have permission here: redirect to '/'. # Request user doesn't have permission here: redirect to '/'.
return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER) return RedirectResponse("/", status_code=HTTPStatus.SEE_OTHER)
@ -776,7 +772,7 @@ async def requests_close_post(request: Request, id: int,
@router.get("/pkgbase/{name}/flag") @router.get("/pkgbase/{name}/flag")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_flag_get(request: Request, name: str): async def pkgbase_flag_get(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
has_cred = request.user.has_credential("CRED_PKGBASE_FLAG") has_cred = request.user.has_credential("CRED_PKGBASE_FLAG")
if not has_cred or pkgbase.Flagger is not None: if not has_cred or pkgbase.Flagger is not None:
@ -792,7 +788,7 @@ async def pkgbase_flag_get(request: Request, name: str):
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_flag_post(request: Request, name: str, async def pkgbase_flag_post(request: Request, name: str,
comments: str = Form(default=str())): comments: str = Form(default=str())):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
if not comments: if not comments:
context = make_context(request, "Flag Package Out-Of-Date") context = make_context(request, "Flag Package Out-Of-Date")
@ -817,7 +813,7 @@ async def pkgbase_flag_post(request: Request, name: str,
@router.post("/pkgbase/{name}/unflag") @router.post("/pkgbase/{name}/unflag")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_unflag(request: Request, name: str): async def pkgbase_unflag(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
has_cred = request.user.has_credential( has_cred = request.user.has_credential(
"CRED_PKGBASE_UNFLAG", approved=[pkgbase.Flagger, pkgbase.Maintainer]) "CRED_PKGBASE_UNFLAG", approved=[pkgbase.Flagger, pkgbase.Maintainer])
@ -834,15 +830,15 @@ async def pkgbase_unflag(request: Request, name: str):
@router.post("/pkgbase/{name}/notify") @router.post("/pkgbase/{name}/notify")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_notify(request: Request, name: str): async def pkgbase_notify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
notif = db.query(pkgbase.notifications.filter( notif = db.query(pkgbase.notifications.filter(
PackageNotification.UserID == request.user.ID models.PackageNotification.UserID == request.user.ID
).exists()).scalar() ).exists()).scalar()
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY") has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
if has_cred and not notif: if has_cred and not notif:
with db.begin(): with db.begin():
db.create(PackageNotification, db.create(models.PackageNotification,
PackageBase=pkgbase, PackageBase=pkgbase,
User=request.user) User=request.user)
@ -853,10 +849,10 @@ async def pkgbase_notify(request: Request, name: str):
@router.post("/pkgbase/{name}/unnotify") @router.post("/pkgbase/{name}/unnotify")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_unnotify(request: Request, name: str): async def pkgbase_unnotify(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
notif = pkgbase.notifications.filter( notif = pkgbase.notifications.filter(
PackageNotification.UserID == request.user.ID models.PackageNotification.UserID == request.user.ID
).first() ).first()
has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY") has_cred = request.user.has_credential("CRED_PKGBASE_NOTIFY")
if has_cred and notif: if has_cred and notif:
@ -870,16 +866,16 @@ async def pkgbase_unnotify(request: Request, name: str):
@router.post("/pkgbase/{name}/vote") @router.post("/pkgbase/{name}/vote")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_vote(request: Request, name: str): async def pkgbase_vote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
vote = pkgbase.package_votes.filter( vote = pkgbase.package_votes.filter(
PackageVote.UsersID == request.user.ID models.PackageVote.UsersID == request.user.ID
).first() ).first()
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE") has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
if has_cred and not vote: if has_cred and not vote:
now = int(datetime.utcnow().timestamp()) now = int(datetime.utcnow().timestamp())
with db.begin(): with db.begin():
db.create(PackageVote, db.create(models.PackageVote,
User=request.user, User=request.user,
PackageBase=pkgbase, PackageBase=pkgbase,
VoteTS=now) VoteTS=now)
@ -895,10 +891,10 @@ async def pkgbase_vote(request: Request, name: str):
@router.post("/pkgbase/{name}/unvote") @router.post("/pkgbase/{name}/unvote")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_unvote(request: Request, name: str): async def pkgbase_unvote(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
vote = pkgbase.package_votes.filter( vote = pkgbase.package_votes.filter(
PackageVote.UsersID == request.user.ID models.PackageVote.UsersID == request.user.ID
).first() ).first()
has_cred = request.user.has_credential("CRED_PKGBASE_VOTE") has_cred = request.user.has_credential("CRED_PKGBASE_VOTE")
if has_cred and vote: if has_cred and vote:
@ -913,7 +909,7 @@ async def pkgbase_unvote(request: Request, name: str):
status_code=HTTPStatus.SEE_OTHER) status_code=HTTPStatus.SEE_OTHER)
def disown_pkgbase(pkgbase: PackageBase, disowner: User): def disown_pkgbase(pkgbase: models.PackageBase, disowner: models.User):
conn = db.ConnectionExecutor(db.get_engine().raw_connection()) conn = db.ConnectionExecutor(db.get_engine().raw_connection())
notif = notify.DisownNotification(conn, disowner.ID, pkgbase.ID) notif = notify.DisownNotification(conn, disowner.ID, pkgbase.ID)
@ -922,7 +918,7 @@ def disown_pkgbase(pkgbase: PackageBase, disowner: User):
pkgbase.Maintainer = None pkgbase.Maintainer = None
else: else:
co = pkgbase.comaintainers.order_by( co = pkgbase.comaintainers.order_by(
PackageComaintainer.Priority.asc() models.PackageComaintainer.Priority.asc()
).limit(1).first() ).limit(1).first()
if co: if co:
@ -938,7 +934,7 @@ def disown_pkgbase(pkgbase: PackageBase, disowner: User):
@router.get("/pkgbase/{name}/disown") @router.get("/pkgbase/{name}/disown")
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_disown_get(request: Request, name: str): async def pkgbase_disown_get(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
has_cred = request.user.has_credential("CRED_PKGBASE_DISOWN", has_cred = request.user.has_credential("CRED_PKGBASE_DISOWN",
approved=[pkgbase.Maintainer]) approved=[pkgbase.Maintainer])
@ -955,7 +951,7 @@ async def pkgbase_disown_get(request: Request, name: str):
@auth_required(True, redirect="/pkgbase/{name}") @auth_required(True, redirect="/pkgbase/{name}")
async def pkgbase_disown_post(request: Request, name: str, async def pkgbase_disown_post(request: Request, name: str,
confirm: bool = Form(default=False)): confirm: bool = Form(default=False)):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
has_cred = request.user.has_credential("CRED_PKGBASE_DISOWN", has_cred = request.user.has_credential("CRED_PKGBASE_DISOWN",
approved=[pkgbase.Maintainer]) approved=[pkgbase.Maintainer])
@ -979,7 +975,7 @@ async def pkgbase_disown_post(request: Request, name: str,
@router.post("/pkgbase/{name}/adopt") @router.post("/pkgbase/{name}/adopt")
@auth_required(True) @auth_required(True)
async def pkgbase_adopt_post(request: Request, name: str): async def pkgbase_adopt_post(request: Request, name: str):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
has_cred = request.user.has_credential("CRED_PKGBASE_ADOPT") has_cred = request.user.has_credential("CRED_PKGBASE_ADOPT")
if has_cred or not pkgbase.Maintainer: if has_cred or not pkgbase.Maintainer:
@ -1001,7 +997,7 @@ async def pkgbase_delete_get(request: Request, name: str):
status_code=HTTPStatus.SEE_OTHER) status_code=HTTPStatus.SEE_OTHER)
context = make_context(request, "Package Deletion") context = make_context(request, "Package Deletion")
context["pkgbase"] = get_pkg_or_base(name, PackageBase) context["pkgbase"] = get_pkg_or_base(name, models.PackageBase)
return render_template(request, "packages/delete.html", context) return render_template(request, "packages/delete.html", context)
@ -1009,7 +1005,7 @@ async def pkgbase_delete_get(request: Request, name: str):
@auth_required(True) @auth_required(True)
async def pkgbase_delete_post(request: Request, name: str, async def pkgbase_delete_post(request: Request, name: str,
confirm: bool = Form(default=False)): confirm: bool = Form(default=False)):
pkgbase = get_pkg_or_base(name, PackageBase) pkgbase = get_pkg_or_base(name, models.PackageBase)
if not request.user.has_credential("CRED_PKGBASE_DELETE"): if not request.user.has_credential("CRED_PKGBASE_DELETE"):
return RedirectResponse(f"/pkgbase/{name}", return RedirectResponse(f"/pkgbase/{name}",

View file

@ -5,8 +5,7 @@ from fastapi.responses import Response
from feedgen.feed import FeedGenerator from feedgen.feed import FeedGenerator
from aurweb import db, util from aurweb import db, util
from aurweb.models.package import Package from aurweb.models import Package, PackageBase
from aurweb.models.package_base import PackageBase
router = APIRouter() router = APIRouter()

View file

@ -10,12 +10,9 @@ from fastapi import APIRouter, Form, HTTPException, Request
from fastapi.responses import RedirectResponse, Response from fastapi.responses import RedirectResponse, Response
from sqlalchemy import and_, or_ from sqlalchemy import and_, or_
from aurweb import db, l10n from aurweb import db, l10n, models
from aurweb.auth import account_type_required, auth_required from aurweb.auth import account_type_required, auth_required
from aurweb.models.account_type import DEVELOPER, TRUSTED_USER, TRUSTED_USER_AND_DEV from aurweb.models.account_type import DEVELOPER, TRUSTED_USER, TRUSTED_USER_AND_DEV
from aurweb.models.tu_vote import TUVote
from aurweb.models.tu_voteinfo import TUVoteInfo
from aurweb.models.user import User
from aurweb.templates import make_context, make_variable_context, render_template from aurweb.templates import make_context, make_variable_context, render_template
router = APIRouter() router = APIRouter()
@ -72,16 +69,18 @@ async def trusted_user(request: Request,
past_by = "desc" past_by = "desc"
context["past_by"] = past_by context["past_by"] = past_by
current_votes = db.query(TUVoteInfo, TUVoteInfo.End > ts).order_by( current_votes = db.query(models.TUVoteInfo).filter(
TUVoteInfo.Submitted.desc()) models.TUVoteInfo.End > ts).order_by(
models.TUVoteInfo.Submitted.desc())
context["current_votes_count"] = current_votes.count() context["current_votes_count"] = current_votes.count()
current_votes = current_votes.limit(pp).offset(current_off) current_votes = current_votes.limit(pp).offset(current_off)
context["current_votes"] = reversed(current_votes.all()) \ context["current_votes"] = reversed(current_votes.all()) \
if current_by == "asc" else current_votes.all() if current_by == "asc" else current_votes.all()
context["current_off"] = current_off context["current_off"] = current_off
past_votes = db.query(TUVoteInfo, TUVoteInfo.End <= ts).order_by( past_votes = db.query(models.TUVoteInfo).filter(
TUVoteInfo.Submitted.desc()) models.TUVoteInfo.End <= ts).order_by(
models.TUVoteInfo.Submitted.desc())
context["past_votes_count"] = past_votes.count() context["past_votes_count"] = past_votes.count()
past_votes = past_votes.limit(pp).offset(past_off) past_votes = past_votes.limit(pp).offset(past_off)
context["past_votes"] = reversed(past_votes.all()) \ context["past_votes"] = reversed(past_votes.all()) \
@ -92,14 +91,14 @@ async def trusted_user(request: Request,
# We order last votes by TUVote.VoteID and User.Username. # We order last votes by TUVote.VoteID and User.Username.
# This is really bad. We should add a Created column to # This is really bad. We should add a Created column to
# TUVote of type Timestamp and order by that instead. # TUVote of type Timestamp and order by that instead.
last_votes_by_tu = db.query(TUVote).filter( last_votes_by_tu = db.query(models.TUVote).filter(
and_(TUVote.VoteID == TUVoteInfo.ID, and_(models.TUVote.VoteID == models.TUVoteInfo.ID,
TUVoteInfo.End <= ts, models.TUVoteInfo.End <= ts,
TUVote.UserID == User.ID, models.TUVote.UserID == models.User.ID,
or_(User.AccountTypeID == 2, or_(models.User.AccountTypeID == 2,
User.AccountTypeID == 4)) models.User.AccountTypeID == 4))
).group_by(User.ID).order_by( ).group_by(models.User.ID).order_by(
TUVote.VoteID.desc(), User.Username.asc()) models.TUVote.VoteID.desc(), models.User.Username.asc())
context["last_votes_by_tu"] = last_votes_by_tu.all() context["last_votes_by_tu"] = last_votes_by_tu.all()
context["current_by_next"] = "asc" if current_by == "desc" else "desc" context["current_by_next"] = "asc" if current_by == "desc" else "desc"
@ -118,9 +117,9 @@ async def trusted_user(request: Request,
def render_proposal(request: Request, def render_proposal(request: Request,
context: dict, context: dict,
proposal: int, proposal: int,
voteinfo: TUVoteInfo, voteinfo: models.TUVoteInfo,
voters: typing.Iterable[User], voters: typing.Iterable[models.User],
vote: TUVote, vote: models.TUVote,
status_code: HTTPStatus = HTTPStatus.OK): status_code: HTTPStatus = HTTPStatus.OK):
""" Render a single TU proposal. """ """ Render a single TU proposal. """
context["proposal"] = proposal context["proposal"] = proposal
@ -135,7 +134,7 @@ def render_proposal(request: Request,
(participation > voteinfo.Quorum and voteinfo.Yes > voteinfo.No) (participation > voteinfo.Quorum and voteinfo.Yes > voteinfo.No)
context["accepted"] = accepted context["accepted"] = accepted
can_vote = voters.filter(TUVote.User == request.user).first() is None can_vote = voters.filter(models.TUVote.User == request.user).first() is None
context["can_vote"] = can_vote context["can_vote"] = can_vote
if not voteinfo.is_running(): if not voteinfo.is_running():
@ -155,13 +154,16 @@ async def trusted_user_proposal(request: Request, proposal: int):
context = await make_variable_context(request, "Trusted User") context = await make_variable_context(request, "Trusted User")
proposal = int(proposal) proposal = int(proposal)
voteinfo = db.query(TUVoteInfo, TUVoteInfo.ID == proposal).first() voteinfo = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.ID == proposal).first()
if not voteinfo: if not voteinfo:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
voters = db.query(User).join(TUVote).filter(TUVote.VoteID == voteinfo.ID) voters = db.query(models.User).join(models.TUVote).filter(
vote = db.query(TUVote, and_(TUVote.UserID == request.user.ID, models.TUVote.VoteID == voteinfo.ID)
TUVote.VoteID == voteinfo.ID)).first() vote = db.query(models.TUVote).filter(
and_(models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID)).first()
if not request.user.is_trusted_user(): if not request.user.is_trusted_user():
context["error"] = "Only Trusted Users are allowed to vote." context["error"] = "Only Trusted Users are allowed to vote."
@ -183,13 +185,16 @@ async def trusted_user_proposal_post(request: Request,
context = await make_variable_context(request, "Trusted User") context = await make_variable_context(request, "Trusted User")
proposal = int(proposal) # Make sure it's an int. proposal = int(proposal) # Make sure it's an int.
voteinfo = db.query(TUVoteInfo, TUVoteInfo.ID == proposal).first() voteinfo = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.ID == proposal).first()
if not voteinfo: if not voteinfo:
raise HTTPException(status_code=HTTPStatus.NOT_FOUND) raise HTTPException(status_code=HTTPStatus.NOT_FOUND)
voters = db.query(User).join(TUVote).filter(TUVote.VoteID == voteinfo.ID) voters = db.query(models.User).join(models.TUVote).filter(
vote = db.query(TUVote, and_(TUVote.UserID == request.user.ID, models.TUVote.VoteID == voteinfo.ID)
TUVote.VoteID == voteinfo.ID)).first() vote = db.query(models.TUVote).filter(
and_(models.TUVote.UserID == request.user.ID,
models.TUVote.VoteID == voteinfo.ID)).first()
status_code = HTTPStatus.OK status_code = HTTPStatus.OK
if not request.user.is_trusted_user(): if not request.user.is_trusted_user():
@ -215,7 +220,7 @@ async def trusted_user_proposal_post(request: Request,
status_code=HTTPStatus.BAD_REQUEST) status_code=HTTPStatus.BAD_REQUEST)
with db.begin(): with db.begin():
vote = db.create(TUVote, User=request.user, VoteInfo=voteinfo) vote = db.create(models.TUVote, User=request.user, VoteInfo=voteinfo)
voteinfo.ActiveTUs += 1 voteinfo.ActiveTUs += 1
context["error"] = "You've already voted for this proposal." context["error"] = "You've already voted for this proposal."
@ -262,12 +267,14 @@ async def trusted_user_addvote_post(request: Request,
# Alright, get some database records, if we can. # Alright, get some database records, if we can.
if type != "bylaws": if type != "bylaws":
user_record = db.query(User, User.Username == user).first() user_record = db.query(models.User).filter(
models.User.Username == user).first()
if user_record is None: if user_record is None:
context["error"] = "Username does not exist." context["error"] = "Username does not exist."
return render_addvote(context, HTTPStatus.NOT_FOUND) return render_addvote(context, HTTPStatus.NOT_FOUND)
voteinfo = db.query(TUVoteInfo, TUVoteInfo.User == user).count() voteinfo = db.query(models.TUVoteInfo).filter(
models.TUVoteInfo.User == user).count()
if voteinfo: if voteinfo:
_ = l10n.get_translator_for_request(request) _ = l10n.get_translator_for_request(request)
context["error"] = _( context["error"] = _(
@ -288,13 +295,14 @@ async def trusted_user_addvote_post(request: Request,
duration, quorum = ADDVOTE_SPECIFICS.get(type) duration, quorum = ADDVOTE_SPECIFICS.get(type)
timestamp = int(datetime.utcnow().timestamp()) timestamp = int(datetime.utcnow().timestamp())
# TODO: Review this. Is this even necessary?
# Remove <script> and <style> tags. # Remove <script> and <style> tags.
agenda = re.sub(r'<[/]?script.*>', '', agenda) agenda = re.sub(r'<[/]?script.*>', '', agenda)
agenda = re.sub(r'<[/]?style.*>', '', agenda) agenda = re.sub(r'<[/]?style.*>', '', agenda)
# Create a new TUVoteInfo (proposal)! # Create a new TUVoteInfo (proposal)!
with db.begin(): with db.begin():
voteinfo = db.create(TUVoteInfo, voteinfo = db.create(models.TUVoteInfo,
User=user, User=user,
Agenda=agenda, Agenda=agenda,
Submitted=timestamp, End=timestamp + duration, Submitted=timestamp, End=timestamp + duration,

View file

@ -139,7 +139,7 @@ def to_qs(query: Dict[str, Any]) -> str:
def get_vote(voteinfo, request: fastapi.Request): def get_vote(voteinfo, request: fastapi.Request):
from aurweb.models.tu_vote import TUVote from aurweb.models import TUVote
return voteinfo.tu_votes.filter(TUVote.User == request.user).first() return voteinfo.tu_votes.filter(TUVote.User == request.user).first()