feat(rpc): add search type handler

This commit introduces a PackageSearch-derivative class: `RPCSearch`.
This derivative modifies callback behavior of PackageSearch to
suit RPC searches, including [make|check|opt]depends `by` types.

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2021-10-30 16:39:20 -07:00
parent ece25e0499
commit af2f3694e7
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
5 changed files with 245 additions and 37 deletions

View file

@ -9,6 +9,9 @@ PP = 50
# A whitelist of valid PP values # A whitelist of valid PP values
PP_WHITELIST = {50, 100, 250} PP_WHITELIST = {50, 100, 250}
# Default `by` parameter for RPC search.
RPC_SEARCH_BY = "name-desc"
def fallback_pp(per_page: int) -> int: def fallback_pp(per_page: int) -> int:
""" If `per_page` is a valid value in PP_WHITELIST, return it. """ If `per_page` is a valid value in PP_WHITELIST, return it.

View file

@ -1,6 +1,7 @@
from sqlalchemy import and_, case, or_, orm from sqlalchemy import and_, case, or_, orm
from aurweb import config, db, models from aurweb import config, db, models, util
from aurweb.models.dependency_type import CHECKDEPENDS_ID, DEPENDS_ID, MAKEDEPENDS_ID, OPTDEPENDS_ID
DEFAULT_MAX_RESULTS = 2500 DEFAULT_MAX_RESULTS = 2500
@ -11,24 +12,25 @@ class PackageSearch:
# A constant mapping of short to full name sort orderings. # A constant mapping of short to full name sort orderings.
FULL_SORT_ORDER = {"d": "desc", "a": "asc"} FULL_SORT_ORDER = {"d": "desc", "a": "asc"}
def __init__(self, user: models.User): def __init__(self, user: models.User = None):
""" Construct an instance of PackageSearch.
This constructors performs several steps during initialization:
1. Setup self.query: an ORM query of Package joined by PackageBase.
"""
self.user = user self.user = user
self.query = db.query(models.Package).join(models.PackageBase).join( self.query = db.query(models.Package).join(models.PackageBase)
models.PackageVote,
and_(models.PackageVote.PackageBaseID == models.PackageBase.ID, if self.user:
models.PackageVote.UsersID == self.user.ID), PackageVote = models.PackageVote
isouter=True join_vote_on = and_(
).join( PackageVote.PackageBaseID == models.PackageBase.ID,
models.PackageNotification, PackageVote.UsersID == self.user.ID)
and_(models.PackageNotification.PackageBaseID == models.PackageBase.ID,
models.PackageNotification.UserID == self.user.ID), PackageNotification = models.PackageNotification
isouter=True join_notif_on = and_(
) PackageNotification.PackageBaseID == models.PackageBase.ID,
PackageNotification.UserID == self.user.ID)
self.query = self.query.join(
models.PackageVote, join_vote_on, isouter=True
).join(models.PackageNotification, join_notif_on, isouter=True)
self.ordering = "d" self.ordering = "d"
# Setup SeB (Search By) callbacks. # Setup SeB (Search By) callbacks.
@ -198,3 +200,83 @@ class PackageSearch:
# Return the query to the user. # Return the query to the user.
return self.query return self.query
class RPCSearch(PackageSearch):
    """ A PackageSearch-derived RPC package search query builder.

    With RPC search, we need a subset of PackageSearch's handlers,
    with a few additional handlers added. So, within the RPCSearch
    constructor, we pop unneeded keys out of inherited self.search_by_cb
    and add a few more keys to it, namely: depends, makedepends,
    optdepends and checkdepends.

    Additionally, some logic within the inherited PackageSearch.search_by
    method is not needed, so it is overridden in this class without
    sanitization done for the PackageSearch `by` argument.
    """

    # Inherited `by` keys which are not exposed through the RPC.
    keys_removed = ("b", "N", "B", "k", "c", "M", "s")

    def __init__(self) -> None:
        """ Construct an RPCSearch without a user and fix up the
        inherited search_by_cb mapping for RPC-specific `by` values. """
        super().__init__()

        # Fix-up inherited search_by_cb to reflect RPC-specific by params.
        # We keep: "nd", "n" and "m". We also overlay four new by params
        # on top: "depends", "makedepends", "optdepends" and "checkdepends".
        util.apply_all(RPCSearch.keys_removed,
                       lambda k: self.search_by_cb.pop(k))
        self.search_by_cb.update({
            "depends": self._search_by_depends,
            "makedepends": self._search_by_makedepends,
            "optdepends": self._search_by_optdepends,
            "checkdepends": self._search_by_checkdepends
        })

    def _join_depends(self, dep_type_id: int) -> orm.Query:
        """ Join Package with PackageDependency and filter results
        based on `dep_type_id`.

        :param dep_type_id: DependencyType ID
        :returns: PackageDependency-joined orm.Query
        """
        self.query = self.query.join(models.PackageDependency).filter(
            models.PackageDependency.DepTypeID == dep_type_id)
        return self.query

    def _search_by_dep_type(self, dep_type_id: int,
                            keywords: str) -> "RPCSearch":
        """ Shared implementation of the four *depends handlers.

        Restricts self.query to packages that have a dependency of type
        `dep_type_id` whose DepName is exactly equal to `keywords`.

        :param dep_type_id: DependencyType ID
        :param keywords: Dependency name to match exactly
        :returns: self
        """
        self.query = self._join_depends(dep_type_id).filter(
            models.PackageDependency.DepName == keywords)
        return self

    def _search_by_depends(self, keywords: str) -> "RPCSearch":
        """ Handler for by=depends. """
        return self._search_by_dep_type(DEPENDS_ID, keywords)

    def _search_by_makedepends(self, keywords: str) -> "RPCSearch":
        """ Handler for by=makedepends. """
        return self._search_by_dep_type(MAKEDEPENDS_ID, keywords)

    def _search_by_optdepends(self, keywords: str) -> "RPCSearch":
        """ Handler for by=optdepends. """
        return self._search_by_dep_type(OPTDEPENDS_ID, keywords)

    def _search_by_checkdepends(self, keywords: str) -> "RPCSearch":
        """ Handler for by=checkdepends. """
        return self._search_by_dep_type(CHECKDEPENDS_ID, keywords)

    def search_by(self, by: str, keywords: str) -> "RPCSearch":
        """ Override inherited search_by. In this override, we reduce the
        scope of what we handle within this function. We do not set `by`
        to a default of "nd" in the RPC, as the RPC returns an error when
        incorrect `by` fields are specified.

        No fallback is applied here: `by` must already be a key of
        self.search_by_cb when this method is called.

        :param by: RPC `by` argument
        :param keywords: RPC `arg` argument
        :returns: The value returned by the selected handler
        """
        callback = self.search_by_cb.get(by)
        result = callback(keywords)
        return result

    def results(self) -> orm.Query:
        """ Return the query built so far. """
        return self.query

View file

@ -9,6 +9,7 @@ import orjson
from fastapi import APIRouter, Query, Request, Response from fastapi import APIRouter, Query, Request, Response
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from aurweb import defaults
from aurweb.ratelimit import check_ratelimit from aurweb.ratelimit import check_ratelimit
from aurweb.rpc import RPC from aurweb.rpc import RPC
@ -62,10 +63,11 @@ def parse_args(request: Request):
@router.get("/rpc") @router.get("/rpc")
async def rpc(request: Request, async def rpc(request: Request,
v: Optional[int] = Query(None), v: Optional[int] = Query(default=None),
type: Optional[str] = Query(None), type: Optional[str] = Query(default=None),
arg: Optional[str] = Query(None), by: Optional[str] = Query(default=defaults.RPC_SEARCH_BY),
args: Optional[List[str]] = Query(None, alias="arg[]")): arg: Optional[str] = Query(default=None),
args: Optional[List[str]] = Query(default=[], alias="arg[]")):
# Create a handle to our RPC class. # Create a handle to our RPC class.
rpc = RPC(version=v, type=type) rpc = RPC(version=v, type=type)
@ -78,7 +80,7 @@ async def rpc(request: Request,
# Prepare list of arguments for input. If 'arg' was given, it'll # Prepare list of arguments for input. If 'arg' was given, it'll
# be a list with one element. # be a list with one element.
arguments = parse_args(request) arguments = parse_args(request)
data = rpc.handle(arguments) data = rpc.handle(by=by, args=arguments)
# Serialize `data` into JSON in a sorted fashion. This way, our # Serialize `data` into JSON in a sorted fashion. This way, our
# ETag header produced below will never end up changed. # ETag header produced below will never end up changed.

View file

@ -5,8 +5,9 @@ from sqlalchemy import and_
import aurweb.config as config import aurweb.config as config
from aurweb import db, models, util from aurweb import db, defaults, models, util
from aurweb.models import dependency_type, relation_type from aurweb.models import dependency_type, relation_type
from aurweb.packages.search import RPCSearch
# Define dependency type mappings from ID to RPC-compatible keys. # Define dependency type mappings from ID to RPC-compatible keys.
DEP_TYPES = { DEP_TYPES = {
@ -60,8 +61,16 @@ class RPC:
"suggest", "suggest-pkgbase" "suggest", "suggest-pkgbase"
} }
# A mapping of aliases. # A mapping of type aliases.
ALIASES = {"info": "multiinfo"} TYPE_ALIASES = {"info": "multiinfo"}
EXPOSED_BYS = {
"name-desc", "name", "maintainer",
"depends", "makedepends", "optdepends", "checkdepends"
}
# A mapping of by aliases.
BY_ALIASES = {"name-desc": "nd", "name": "n", "maintainer": "m"}
def __init__(self, version: int = 0, type: str = None): def __init__(self, version: int = 0, type: str = None):
self.version = version self.version = version
@ -76,14 +85,17 @@ class RPC:
"error": message "error": message
} }
def _verify_inputs(self, args: List[str] = []): def _verify_inputs(self, by: str = [], args: List[str] = []):
if self.version is None: if self.version is None:
raise RPCError("Please specify an API version.") raise RPCError("Please specify an API version.")
if self.version not in RPC.EXPOSED_VERSIONS: if self.version not in RPC.EXPOSED_VERSIONS:
raise RPCError("Invalid version specified.") raise RPCError("Invalid version specified.")
if self.type is None or not len(args): if by not in RPC.EXPOSED_BYS:
raise RPCError("Incorrect by field specified.")
if self.type is None:
raise RPCError("No request type/data specified.") raise RPCError("No request type/data specified.")
if self.type not in RPC.EXPOSED_TYPES: if self.type not in RPC.EXPOSED_TYPES:
@ -95,6 +107,10 @@ class RPC:
raise RPCError( raise RPCError(
f"Request type '{self.type}' is not yet implemented.") f"Request type '{self.type}' is not yet implemented.")
def _enforce_args(self, args: List[str]):
    """ Require that at least one request argument was supplied.

    :param args: Deciphered list of request arguments
    :raises RPCError: When `args` is empty
    """
    if args:
        return
    raise RPCError("No request type/data specified.")
def _update_json_depends(self, package: models.Package, def _update_json_depends(self, package: models.Package,
data: Dict[str, Any]): data: Dict[str, Any]):
# Walk through all related PackageDependencies and produce # Walk through all related PackageDependencies and produce
@ -169,13 +185,36 @@ class RPC:
self._update_json_relations(package, data) self._update_json_relations(package, data)
return data return data
def _handle_multiinfo_type(self, args: List[str] = []): def _handle_multiinfo_type(self, args: List[str] = [], **kwargs):
self._enforce_args(args)
args = set(args) args = set(args)
packages = db.query(models.Package).filter( packages = db.query(models.Package).filter(
models.Package.Name.in_(args)) models.Package.Name.in_(args))
return [self._get_json_data(pkg) for pkg in packages] return [self._get_json_data(pkg) for pkg in packages]
def _handle_suggest_type(self, args: List[str] = []): def _handle_search_type(self, by: str = defaults.RPC_SEARCH_BY,
args: List[str] = []):
# If `by` isn't maintainer and we don't have any args, raise an error.
# In maintainer's case, return all orphans if there are no args,
# so we need args to pass through to the handler without errors.
if by != "m" and not len(args):
raise RPCError("No request type/data specified.")
arg = args[0]
if len(arg) < 2:
raise RPCError("Query arg too small.")
search = RPCSearch()
search.search_by(by, arg)
max_results = config.getint("options", "max_rpc_results")
results = search.results().limit(max_results)
return [self._get_json_data(pkg) for pkg in results]
def _handle_suggest_type(self, args: List[str] = [], **kwargs):
if not args:
return []
arg = args[0] arg = args[0]
packages = db.query(models.Package).join(models.PackageBase).filter( packages = db.query(models.Package).join(models.PackageBase).filter(
and_(models.PackageBase.PackagerUID.isnot(None), and_(models.PackageBase.PackagerUID.isnot(None),
@ -183,14 +222,17 @@ class RPC:
).order_by(models.Package.Name.asc()).limit(20) ).order_by(models.Package.Name.asc()).limit(20)
return [pkg.Name for pkg in packages] return [pkg.Name for pkg in packages]
def _handle_suggest_pkgbase_type(self, args: List[str] = []): def _handle_suggest_pkgbase_type(self, args: List[str] = [], **kwargs):
if not args:
return []
records = db.query(models.PackageBase).filter( records = db.query(models.PackageBase).filter(
and_(models.PackageBase.PackagerUID.isnot(None), and_(models.PackageBase.PackagerUID.isnot(None),
models.PackageBase.Name.like(f"%{args[0]}%")) models.PackageBase.Name.like(f"%{args[0]}%"))
).order_by(models.PackageBase.Name.asc()).limit(20) ).order_by(models.PackageBase.Name.asc()).limit(20)
return [record.Name for record in records] return [record.Name for record in records]
def handle(self, args: List[str] = []): def handle(self, by: str = defaults.RPC_SEARCH_BY, args: List[str] = []):
""" Request entrypoint. A router should pass v, type and args """ Request entrypoint. A router should pass v, type and args
to this function and expect an output dictionary to be returned. to this function and expect an output dictionary to be returned.
@ -199,22 +241,29 @@ class RPC:
:param args: Deciphered list of arguments based on arg/arg[] inputs :param args: Deciphered list of arguments based on arg/arg[] inputs
""" """
# Convert type aliased types. # Convert type aliased types.
if self.type in RPC.ALIASES: if self.type in RPC.TYPE_ALIASES:
self.type = RPC.ALIASES.get(self.type) self.type = RPC.TYPE_ALIASES.get(self.type)
# Prepare our output data dictionary with some basic keys. # Prepare our output data dictionary with some basic keys.
data = {"version": self.version, "type": self.type} data = {"version": self.version, "type": self.type}
# Run some verification on our given arguments. # Run some verification on our given arguments.
try: try:
self._verify_inputs(args) self._verify_inputs(by=by, args=args)
except RPCError as exc: except RPCError as exc:
return self.error(str(exc)) return self.error(str(exc))
# Convert by to its aliased value if it has one.
if by in RPC.BY_ALIASES:
by = RPC.BY_ALIASES.get(by)
# Get a handle to our callback and trap an RPCError with # Get a handle to our callback and trap an RPCError with
# an empty list of results based on callback's execution. # an empty list of results based on callback's execution.
callback = getattr(self, f"_handle_{self.type.replace('-', '_')}_type") callback = getattr(self, f"_handle_{self.type.replace('-', '_')}_type")
results = callback(args) try:
results = callback(by=by, args=args)
except RPCError as exc:
return self.error(str(exc))
# These types are special: we produce a different kind of # These types are special: we produce a different kind of
# successful JSON output: a list of results. # successful JSON output: a list of results.

View file

@ -461,6 +461,11 @@ def test_rpc_suggest_pkgbase():
data = response.json() data = response.json()
assert data == ["chungy-chungus"] assert data == ["chungy-chungus"]
# Test no arg supplied.
response = make_request("/rpc?v=5&type=suggest-pkgbase")
data = response.json()
assert data == []
def test_rpc_suggest(): def test_rpc_suggest():
response = make_request("/rpc?v=5&type=suggest&arg=other") response = make_request("/rpc?v=5&type=suggest&arg=other")
@ -472,9 +477,14 @@ def test_rpc_suggest():
data = response.json() data = response.json()
assert data == [] assert data == []
# Test no arg supplied.
response = make_request("/rpc?v=5&type=suggest")
data = response.json()
assert data == []
def test_rpc_unimplemented_types(): def test_rpc_unimplemented_types():
unimplemented = ["search", "msearch"] unimplemented = ["msearch"]
for type in unimplemented: for type in unimplemented:
response = make_request(f"/rpc?v=5&type={type}&arg=big") response = make_request(f"/rpc?v=5&type={type}&arg=big")
data = response.json() data = response.json()
@ -518,3 +528,65 @@ def test_rpc_etag():
assert response1.headers.get("ETag") is not None assert response1.headers.get("ETag") is not None
assert response1.headers.get("ETag") != str() assert response1.headers.get("ETag") != str()
assert response1.headers.get("ETag") == response2.headers.get("ETag") assert response1.headers.get("ETag") == response2.headers.get("ETag")
def test_rpc_search_arg_too_small():
    """ A one-character search arg produces the "too small" error. """
    resp = make_request("/rpc?v=5&type=search&arg=b")
    assert resp.status_code == int(HTTPStatus.OK)
    payload = resp.json()
    assert payload.get("error") == "Query arg too small."
def test_rpc_search():
    """ A default (`by` omitted) search for "big" returns big-chungus,
    and a search without any arg returns an error. """
    resp = make_request("/rpc?v=5&type=search&arg=big")
    assert resp.status_code == int(HTTPStatus.OK)

    payload = resp.json()
    assert payload.get("resultcount") == 1

    first = payload.get("results")[0]
    assert first.get("Name") == "big-chungus"

    # No args on non-m by types return an error.
    resp = make_request("/rpc?v=5&type=search")
    payload = resp.json()
    assert payload.get("error") == "No request type/data specified."
def test_rpc_search_depends():
    """ Searching by=depends for chungus-depends returns big-chungus. """
    endpoint = "/rpc?v=5&type=search&by=depends&arg=chungus-depends"
    payload = make_request(endpoint).json()
    assert payload.get("resultcount") == 1
    first = payload.get("results")[0]
    assert first.get("Name") == "big-chungus"
def test_rpc_search_makedepends():
    """ Searching by=makedepends for chungus-makedepends returns
    big-chungus. """
    endpoint = "/rpc?v=5&type=search&by=makedepends&arg=chungus-makedepends"
    payload = make_request(endpoint).json()
    assert payload.get("resultcount") == 1
    first = payload.get("results")[0]
    assert first.get("Name") == "big-chungus"
def test_rpc_search_optdepends():
    """ Searching by=optdepends for chungus-optdepends returns
    big-chungus. """
    endpoint = "/rpc?v=5&type=search&by=optdepends&arg=chungus-optdepends"
    payload = make_request(endpoint).json()
    assert payload.get("resultcount") == 1
    first = payload.get("results")[0]
    assert first.get("Name") == "big-chungus"
def test_rpc_search_checkdepends():
    """ Searching by=checkdepends for chungus-checkdepends returns
    big-chungus. """
    endpoint = "/rpc?v=5&type=search&by=checkdepends&arg=chungus-checkdepends"
    payload = make_request(endpoint).json()
    assert payload.get("resultcount") == 1
    first = payload.get("results")[0]
    assert first.get("Name") == "big-chungus"
def test_rpc_incorrect_by():
    """ An unknown `by` value produces the "Incorrect by field" error. """
    resp = make_request("/rpc?v=5&type=search&by=fake&arg=big")
    payload = resp.json()
    assert payload.get("error") == "Incorrect by field specified."