fix: support multiple SSHPubKey records per user

There was one blazing issue with the previous implementation regardless
of the multiple records: we were generating fingerprints by storing
the key into a file and reading it with ssh-keygen. This is absolutely
terrible and was not meant to be left around (it was forgotten, my bad).

Took this opportunity to clean up a few things:
- simplify pubkey validation
- centralize things a bit better

Signed-off-by: Kevin Morris <kevr@0cost.org>
This commit is contained in:
Kevin Morris 2022-02-08 07:50:15 -08:00
parent 660d57340a
commit 4c14a10b91
No known key found for this signature in database
GPG key ID: F7E46DED420788F3
11 changed files with 162 additions and 108 deletions

View file

@ -1,6 +1,3 @@
import os
import tempfile
from subprocess import PIPE, Popen
from sqlalchemy.orm import backref, relationship
@ -15,28 +12,17 @@ class SSHPubKey(Base):
__mapper_args__ = {"primary_key": [__table__.c.Fingerprint]}
User = relationship(
"User", backref=backref("ssh_pub_key", uselist=False),
"User", backref=backref("ssh_pub_keys", lazy="dynamic"),
foreign_keys=[__table__.c.UserID])
def __init__(self, **kwargs):
super().__init__(**kwargs)
def get_fingerprint(pubkey):
with tempfile.TemporaryDirectory() as tmpdir:
pk = os.path.join(tmpdir, "ssh.pub")
with open(pk, "w") as f:
f.write(pubkey)
proc = Popen(["ssh-keygen", "-l", "-f", pk], stdout=PIPE, stderr=PIPE)
out, err = proc.communicate()
# Invalid SSH Public Key. Return None to the caller.
if proc.returncode != 0:
return None
parts = out.decode().split()
fp = parts[1].replace("SHA256:", "")
return fp
def get_fingerprint(pubkey: str) -> str:
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE,
stderr=PIPE)
out, _ = proc.communicate(pubkey.encode())
if proc.returncode:
raise ValueError("The SSH public key is invalid.")
return out.decode().split()[1].split(":", 1)[1]

View file

@ -2,6 +2,7 @@ import copy
import typing
from http import HTTPStatus
from typing import Any, Dict
from fastapi import APIRouter, Form, Request
from fastapi.responses import HTMLResponse, RedirectResponse
@ -105,7 +106,8 @@ async def passreset_post(request: Request,
status_code=HTTPStatus.SEE_OTHER)
def process_account_form(request: Request, user: models.User, args: dict):
def process_account_form(request: Request, user: models.User,
args: Dict[str, Any]):
""" Process an account form. All fields are optional and only checks
requirements in the case they are present.
@ -193,8 +195,8 @@ def make_account_form_context(context: dict,
context["pgp"] = args.get("K", user.PGPKey or str())
context["lang"] = args.get("L", user.LangPreference)
context["tz"] = args.get("TZ", user.Timezone)
ssh_pk = user.ssh_pub_key.PubKey if user.ssh_pub_key else str()
context["ssh_pk"] = args.get("PK", ssh_pk)
ssh_pks = [pk.PubKey for pk in user.ssh_pub_keys]
context["ssh_pks"] = args.get("PK", ssh_pks)
context["cn"] = args.get("CN", user.CommentNotify)
context["un"] = args.get("UN", user.UpdateNotify)
context["on"] = args.get("ON", user.OwnershipNotify)
@ -212,7 +214,7 @@ def make_account_form_context(context: dict,
context["pgp"] = args.get("K", str())
context["lang"] = args.get("L", context.get("language"))
context["tz"] = args.get("TZ", context.get("timezone"))
context["ssh_pk"] = args.get("PK", str())
context["ssh_pks"] = args.get("PK", str())
context["cn"] = args.get("CN", True)
context["un"] = args.get("UN", False)
context["on"] = args.get("ON", True)
@ -314,16 +316,13 @@ async def account_register_post(request: Request,
# PK mismatches the existing user's SSHPubKey.PubKey.
if PK:
# Get the second element in the PK, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
with db.begin():
user.ssh_pub_key = models.SSHPubKey(UserID=user.ID,
PubKey=pubkey,
Fingerprint=fingerprint)
keys = util.parse_ssh_keys(PK.strip())
for k in keys:
pk = " ".join(k)
fprint = get_fingerprint(pk)
with db.begin():
db.create(models.SSHPubKey, UserID=user.ID,
PubKey=pk, Fingerprint=fprint)
# Send a reset key notification to the new user.
WelcomeNotification(user.ID).send()
@ -409,6 +408,9 @@ async def account_edit_post(request: Request,
context = make_account_form_context(context, request, user, args)
ok, errors = process_account_form(request, user, args)
if PK:
context["ssh_pks"] = [PK]
if not passwd:
context["errors"] = ["Invalid password."]
return render_template(request, "account/edit.html", context,

View file

@ -2,7 +2,8 @@ from typing import Any, Dict
from fastapi import Request
from aurweb import cookies, db, models, time
from aurweb import cookies, db, models, time, util
from aurweb.models import SSHPubKey
from aurweb.models.ssh_pub_key import get_fingerprint
from aurweb.util import strtobool
@ -52,32 +53,35 @@ def timezone(TZ: str = str(),
context["language"] = TZ
def ssh_pubkey(PK: str = str(),
user: models.User = None,
**kwargs) -> None:
# If a PK is given, compare it against the target user's PK.
if PK:
# Get the second token in the public key, which is the actual key.
pubkey = PK.strip().rstrip()
parts = pubkey.split(" ")
if len(parts) == 3:
# Remove the host part.
pubkey = parts[0] + " " + parts[1]
fingerprint = get_fingerprint(pubkey)
if not user.ssh_pub_key:
# No public key exists, create one.
with db.begin():
db.create(models.SSHPubKey, UserID=user.ID,
PubKey=pubkey, Fingerprint=fingerprint)
elif user.ssh_pub_key.PubKey != pubkey:
# A public key already exists, update it.
with db.begin():
user.ssh_pub_key.PubKey = pubkey
user.ssh_pub_key.Fingerprint = fingerprint
elif user.ssh_pub_key:
# Else, if the user has a public key already, delete it.
def ssh_pubkey(PK: str = str(), user: models.User = None, **kwargs) -> None:
if not PK:
# If no pubkey is provided, wipe out any pubkeys the user
# has and return out early.
with db.begin():
db.delete(user.ssh_pub_key)
db.delete_all(user.ssh_pub_keys)
return
# Otherwise, parse ssh keys and their fprints out of PK.
keys = util.parse_ssh_keys(PK.strip())
fprints = [get_fingerprint(" ".join(k)) for k in keys]
with db.begin():
# Delete any existing keys we can't find.
to_remove = user.ssh_pub_keys.filter(
~SSHPubKey.Fingerprint.in_(fprints))
db.delete_all(to_remove)
# For each key, if it does not yet exist, create it.
for i, full_key in enumerate(keys):
prefix, key = full_key
exists = user.ssh_pub_keys.filter(
SSHPubKey.Fingerprint == fprints[i]
).exists()
if not db.query(exists).scalar():
# No public key exists, create one.
db.create(models.SSHPubKey, UserID=user.ID,
PubKey=" ".join([prefix, key]),
Fingerprint=fprints[i])
def account_type(T: int = None,

View file

@ -107,14 +107,16 @@ def invalid_pgp_key(K: str = str(), **kwargs) -> None:
def invalid_ssh_pubkey(PK: str = str(), user: models.User = None,
_: l10n.Translator = None, **kwargs) -> None:
if PK:
invalid_exc = ValidationError(["The SSH public key is invalid."])
if not util.valid_ssh_pubkey(PK):
raise invalid_exc
if not PK:
return
fingerprint = get_fingerprint(PK.strip().rstrip())
if not fingerprint:
raise invalid_exc
try:
keys = util.parse_ssh_keys(PK.strip())
except ValueError as exc:
raise ValidationError([str(exc)])
for prefix, key in keys:
fingerprint = get_fingerprint(f"{prefix} {key}")
exists = db.query(models.SSHPubKey).filter(
and_(models.SSHPubKey.UserID != user.ID,

View file

@ -1,4 +1,3 @@
import base64
import math
import re
import secrets
@ -7,7 +6,8 @@ import string
from datetime import datetime
from distutils.util import strtobool as _strtobool
from http import HTTPStatus
from typing import Callable, Iterable, Tuple, Union
from subprocess import PIPE, Popen
from typing import Callable, Iterable, List, Tuple, Union
from urllib.parse import urlparse
import fastapi
@ -82,25 +82,6 @@ def valid_pgp_fingerprint(fp):
return len(fp) == 40
def valid_ssh_pubkey(pk):
valid_prefixes = aurweb.config.get("auth", "valid-keytypes")
valid_prefixes = set(valid_prefixes.split(" "))
has_valid_prefix = False
for prefix in valid_prefixes:
if "%s " % prefix in pk:
has_valid_prefix = True
break
if not has_valid_prefix:
return False
tokens = pk.strip().rstrip().split(" ")
if len(tokens) < 2:
return False
return base64.b64encode(base64.b64decode(tokens[1])).decode() == tokens[1]
def jsonify(obj):
""" Perform a conversion on obj if it's needed. """
if isinstance(obj, datetime):
@ -191,3 +172,29 @@ async def error_or_result(next: Callable, *args, **kwargs) \
status_code = HTTPStatus.INTERNAL_SERVER_ERROR
return JSONResponse({"error": str(exc)}, status_code=status_code)
return response
def parse_ssh_key(string: str) -> Tuple[str, str]:
""" Parse an SSH public key. """
invalid_exc = ValueError("The SSH public key is invalid.")
parts = re.sub(r'\s\s+', ' ', string.strip()).split()
if len(parts) < 2:
raise invalid_exc
prefix, key = parts[:2]
prefixes = set(aurweb.config.get("auth", "valid-keytypes").split(" "))
if prefix not in prefixes:
raise invalid_exc
proc = Popen(["ssh-keygen", "-l", "-f", "-"], stdin=PIPE, stdout=PIPE,
stderr=PIPE)
out, _ = proc.communicate(f"{prefix} {key}".encode())
if proc.returncode:
raise invalid_exc
return (prefix, key)
def parse_ssh_keys(string: str) -> List[Tuple[str, str]]:
""" Parse a list of SSH public keys. """
return [parse_ssh_key(e) for e in string.splitlines()]