feat: Switch to postgres

Migrate from MariaDB to PostgreSQL.

Signed-off-by: moson <moson@archlinux.org>
This commit is contained in:
moson 2023-11-30 15:13:42 +01:00
parent 3220cf886e
commit db8e2458f9
No known key found for this signature in database
GPG key ID: 4A4760AB4EE15296
64 changed files with 572 additions and 629 deletions

View file

@ -60,7 +60,7 @@ other user:
GRANT ALL ON *.* TO 'user'@'localhost' WITH GRANT OPTION
The aurweb platform is intended to use the `mysql` backend, but
The aurweb platform is intended to use the `postgresql` backend, but
the `sqlite` backend is still used for sharness tests. These tests
will soon be replaced with pytest suites and `sqlite` removed.

View file

@ -52,6 +52,7 @@ from sqlalchemy.orm import scoped_session
import aurweb.config
import aurweb.db
import aurweb.schema
from aurweb import aur_logging, initdb, testing
from aurweb.testing.email import Email
from aurweb.testing.git import GitRepository
@ -68,25 +69,28 @@ values.ValueClass = values.MutexValue
def test_engine() -> Engine:
"""
Return a privileged SQLAlchemy engine with no database.
Return a privileged SQLAlchemy engine with default database.
This method is particularly useful for providing an engine that
can be used to create and drop databases from an SQL server.
:return: SQLAlchemy Engine instance (not connected to a database)
:return: SQLAlchemy Engine instance (connected to a default)
"""
unix_socket = aurweb.config.get_with_fallback("database", "socket", None)
socket = aurweb.config.get_with_fallback("database", "socket", None)
host = aurweb.config.get_with_fallback("database", "host", None)
port = aurweb.config.get_with_fallback("database", "port", None)
kwargs = {
"database": aurweb.config.get("database", "name"),
"username": aurweb.config.get("database", "user"),
"password": aurweb.config.get_with_fallback("database", "password", None),
"host": aurweb.config.get("database", "host"),
"port": aurweb.config.get_with_fallback("database", "port", None),
"query": {"unix_socket": unix_socket},
"host": socket if socket else host,
"port": port if not socket else None,
}
backend = aurweb.config.get("database", "backend")
driver = aurweb.db.DRIVERS.get(backend)
return create_engine(URL.create(driver, **kwargs))
return create_engine(URL.create(driver, **kwargs), isolation_level="AUTOCOMMIT")
class AlembicArgs:
@ -116,7 +120,7 @@ def _create_database(engine: Engine, dbname: str) -> None:
# a ProgrammingError. Just drop the database and try
# again. If at that point things still fail, any
# exception will be propogated up to the caller.
conn.execute(f"DROP DATABASE {dbname}")
conn.execute(f"DROP DATABASE {dbname} WITH (FORCE)")
conn.execute(f"CREATE DATABASE {dbname}")
conn.close()
initdb.run(AlembicArgs)
@ -129,9 +133,8 @@ def _drop_database(engine: Engine, dbname: str) -> None:
:param engine: Engine returned by test_engine()
:param dbname: Database name to drop
"""
aurweb.schema.metadata.drop_all(bind=engine)
conn = engine.connect()
conn.execute(f"DROP DATABASE {dbname}")
conn.execute(f"DROP DATABASE {dbname} WITH (FORCE)")
conn.close()
@ -178,6 +181,10 @@ def db_session(setup_database: None) -> scoped_session:
session.close()
aurweb.db.pop_session(dbname)
# Dispose engine and close connections
aurweb.db.get_engine(dbname).dispose()
aurweb.db.pop_engine(dbname)
@pytest.fixture
def db_test(db_session: scoped_session) -> None:

View file

@ -14,7 +14,7 @@ from aurweb.models.user import User
from aurweb.testing.html import get_errors
# Some test global constants.
TEST_USERNAME = "test"
TEST_USERNAME = "Test"
TEST_EMAIL = "test@example.org"
TEST_REFERER = {
"referer": aurweb.config.get("options", "aur_location") + "/login",
@ -54,36 +54,37 @@ def user() -> User:
def test_login_logout(client: TestClient, user: User):
post_data = {"user": "test", "passwd": "testPassword", "next": "/"}
for username in ["test", "TEst"]:
post_data = {"user": username, "passwd": "testPassword", "next": "/"}
with client as request:
# First, let's test get /login.
response = request.get("/login")
assert response.status_code == int(HTTPStatus.OK)
with client as request:
# First, let's test get /login.
response = request.get("/login")
assert response.status_code == int(HTTPStatus.OK)
response = request.post("/login", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
response = request.post("/login", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
# Simulate following the redirect location from above's response.
response = request.get(response.headers.get("location"))
assert response.status_code == int(HTTPStatus.OK)
# Simulate following the redirect location from above's response.
response = request.get(response.headers.get("location"))
assert response.status_code == int(HTTPStatus.OK)
response = request.post("/logout", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
response = request.post("/logout", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
request.cookies = {"AURSID": response.cookies.get("AURSID")}
response = request.post(
"/logout",
data=post_data,
)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
request.cookies = {"AURSID": response.cookies.get("AURSID")}
response = request.post(
"/logout",
data=post_data,
)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
assert "AURSID" not in response.cookies
assert "AURSID" not in response.cookies
def test_login_suspended(client: TestClient, user: User):
with db.begin():
user.Suspended = 1
user.Suspended = True
data = {"user": user.Username, "passwd": "testPassword", "next": "/"}
with client as request:
@ -184,23 +185,23 @@ def test_secure_login(getboolean: mock.Mock, client: TestClient, user: User):
def test_authenticated_login(client: TestClient, user: User):
post_data = {"user": user.Username, "passwd": "testPassword", "next": "/"}
for username in [user.Username.lower(), user.Username.upper()]:
post_data = {"user": username, "passwd": "testPassword", "next": "/"}
with client as request:
# Try to login.
response = request.post("/login", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
assert response.headers.get("location") == "/"
with client as request:
# Try to login.
request.cookies = {}
response = request.post("/login", data=post_data)
assert response.status_code == int(HTTPStatus.SEE_OTHER)
assert response.headers.get("location") == "/"
# Now, let's verify that we get the logged in rendering
# when requesting GET /login as an authenticated user.
# Now, let's verify that we receive 403 Forbidden when we
# try to get /login as an authenticated user.
request.cookies = response.cookies
response = request.get("/login")
# Now, let's verify that we get the logged in rendering
# when requesting GET /login as an authenticated user.
request.cookies = response.cookies
response = request.get("/login")
assert response.status_code == int(HTTPStatus.OK)
assert "Logged-in as: <strong>test</strong>" in response.text
assert response.status_code == int(HTTPStatus.OK)
assert f"Logged-in as: <strong>{user.Username}</strong>" in response.text
def test_unauthenticated_logout_unauthorized(client: TestClient):
@ -370,5 +371,4 @@ def test_generate_unique_sid_exhausted(
assert re.search(expr, caplog.text)
assert "IntegrityError" in caplog.text
expr = r"Duplicate entry .+ for key .+SessionID.+"
assert re.search(expr, response.text)
assert "duplicate key value" in response.text

View file

@ -93,9 +93,9 @@ def make_temp_sqlite_config():
)
def make_temp_mysql_config():
def make_temp_postgres_config():
return make_temp_config(
(r"backend = .*", "backend = mysql"), (r"name = .*", "name = aurweb_test")
(r"backend = .*", "backend = postgres"), (r"name = .*", "name = aurweb_test")
)
@ -114,8 +114,8 @@ def test_sqlalchemy_sqlite_url():
aurweb.config.rehash()
def test_sqlalchemy_mysql_url():
tmpctx, tmp = make_temp_mysql_config()
def test_sqlalchemy_postgres_url():
tmpctx, tmp = make_temp_postgres_config()
with tmpctx:
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
aurweb.config.rehash()
@ -123,8 +123,8 @@ def test_sqlalchemy_mysql_url():
aurweb.config.rehash()
def test_sqlalchemy_mysql_port_url():
tmpctx, tmp = make_temp_config((r";port = 3306", "port = 3306"))
def test_sqlalchemy_postgres_port_url():
tmpctx, tmp = make_temp_config((r";port = 5432", "port = 5432"))
with tmpctx:
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
@ -133,7 +133,7 @@ def test_sqlalchemy_mysql_port_url():
aurweb.config.rehash()
def test_sqlalchemy_mysql_socket_url():
def test_sqlalchemy_postgres_socket_url():
tmpctx, tmp = make_temp_config()
with tmpctx:
@ -170,16 +170,6 @@ def test_connection_class_unsupported_backend():
aurweb.config.rehash()
@mock.patch("MySQLdb.connect", mock.MagicMock(return_value=True))
def test_connection_mysql():
tmpctx, tmp = make_temp_mysql_config()
with tmpctx:
with mock.patch.dict(os.environ, {"AUR_CONFIG": tmp}):
aurweb.config.rehash()
db.Connection()
aurweb.config.rehash()
def test_create_delete():
with db.begin():
account_type = db.create(AccountType, AccountType="test")
@ -212,8 +202,8 @@ def test_add_commit():
db.delete(account_type)
def test_connection_executor_mysql_paramstyle():
executor = db.ConnectionExecutor(None, backend="mysql")
def test_connection_executor_postgres_paramstyle():
executor = db.ConnectionExecutor(None, backend="postgres")
assert executor.paramstyle() == "format"

View file

@ -20,7 +20,7 @@ def test_run():
from aurweb.schema import metadata
aurweb.db.kill_engine()
metadata.drop_all(aurweb.db.get_engine())
metadata.drop_all(aurweb.db.get_engine(), checkfirst=False)
aurweb.initdb.run(Args())
# Check that constant table rows got added via initdb.

View file

@ -227,7 +227,7 @@ please go to the package page [2] and select "Disable notifications".
def test_update(user: User, user2: User, pkgbases: list[PackageBase]):
pkgbase = pkgbases[0]
with db.begin():
user.UpdateNotify = 1
user.UpdateNotify = True
notif = notify.UpdateNotification(user2.ID, pkgbase.ID)
notif.send()
@ -331,7 +331,7 @@ You were removed from the co-maintainer list of {pkgbase.Name} [1].
def test_suspended_ownership_change(user: User, pkgbases: list[PackageBase]):
with db.begin():
user.Suspended = 1
user.Suspended = True
pkgbase = pkgbases[0]
notif = notify.ComaintainerAddNotification(user.ID, pkgbase.ID)
@ -491,7 +491,7 @@ def test_open_close_request_hidden_email(
# Enable the "HideEmail" option for our requester
with db.begin():
user2.HideEmail = 1
user2.HideEmail = True
# Send an open request notification.
notif = notify.RequestOpenNotification(

View file

@ -350,7 +350,7 @@ def test_pm_index_table_paging(client, pm_user):
VoteInfo,
Agenda=f"Agenda #{i}",
User=pm_user.Username,
Submitted=(ts - 5),
Submitted=(ts - 5 - i),
End=(ts + 1000),
Quorum=0.0,
Submitter=pm_user,
@ -362,7 +362,7 @@ def test_pm_index_table_paging(client, pm_user):
VoteInfo,
Agenda=f"Agenda #{25 + i}",
User=pm_user.Username,
Submitted=(ts - 1000),
Submitted=(ts - 1000 - i),
End=(ts - 5),
Quorum=0.0,
Submitter=pm_user,

View file

@ -742,14 +742,15 @@ def test_packages_empty(client: TestClient):
def test_packages_search_by_name(client: TestClient, packages: list[Package]):
with client as request:
response = request.get("/packages", params={"SeB": "n", "K": "pkg_"})
assert response.status_code == int(HTTPStatus.OK)
for keyword in ["pkg_", "PkG_"]:
with client as request:
response = request.get("/packages", params={"SeB": "n", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 50 # Default per-page
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 50 # Default per-page
def test_packages_search_by_exact_name(client: TestClient, packages: list[Package]):
@ -763,26 +764,28 @@ def test_packages_search_by_exact_name(client: TestClient, packages: list[Packag
# There is no package named exactly 'pkg_', we get 0 results.
assert len(rows) == 0
with client as request:
response = request.get("/packages", params={"SeB": "N", "K": "pkg_1"})
assert response.status_code == int(HTTPStatus.OK)
for keyword in ["pkg_1", "PkG_1"]:
with client as request:
response = request.get("/packages", params={"SeB": "N", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
# There's just one package named 'pkg_1', we get 1 result.
assert len(rows) == 1
# There's just one package named 'pkg_1', we get 1 result.
assert len(rows) == 1
def test_packages_search_by_pkgbase(client: TestClient, packages: list[Package]):
with client as request:
response = request.get("/packages", params={"SeB": "b", "K": "pkg_"})
assert response.status_code == int(HTTPStatus.OK)
for keyword in ["pkg_", "PkG_"]:
with client as request:
response = request.get("/packages", params={"SeB": "b", "K": "pkg_"})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 50
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 50
def test_packages_search_by_exact_pkgbase(client: TestClient, packages: list[Package]):
@ -794,13 +797,14 @@ def test_packages_search_by_exact_pkgbase(client: TestClient, packages: list[Pac
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 0
with client as request:
response = request.get("/packages", params={"SeB": "B", "K": "pkg_1"})
assert response.status_code == int(HTTPStatus.OK)
for keyword in ["pkg_1", "PkG_1"]:
with client as request:
response = request.get("/packages", params={"SeB": "B", "K": "pkg_1"})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
def test_packages_search_by_keywords(client: TestClient, packages: list[Package]):
@ -821,15 +825,16 @@ def test_packages_search_by_keywords(client: TestClient, packages: list[Package]
)
# And request packages with that keyword, we should get 1 result.
with client as request:
# clear fakeredis cache
cache._redis.flushall()
response = request.get("/packages", params={"SeB": "k", "K": "testKeyword"})
assert response.status_code == int(HTTPStatus.OK)
for keyword in ["testkeyword", "TestKeyWord"]:
with client as request:
# clear fakeredis cache
cache._redis.flushall()
response = request.get("/packages", params={"SeB": "k", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
# Now let's add another keyword to the same package
with db.begin():
@ -854,14 +859,13 @@ def test_packages_search_by_maintainer(
):
# We should expect that searching by `package`'s maintainer
# returns `package` in the results.
with client as request:
response = request.get(
"/packages", params={"SeB": "m", "K": maintainer.Username}
)
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
for keyword in [maintainer.Username, maintainer.Username.upper()]:
with client as request:
response = request.get("/packages", params={"SeB": "m", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
# Search again by maintainer with no keywords given.
# This kind of search returns all orphans instead.
@ -912,17 +916,16 @@ def test_packages_search_by_comaintainer(
)
# Then test that it's returned by our search.
with client as request:
# clear fakeredis cache
cache._redis.flushall()
response = request.get(
"/packages", params={"SeB": "c", "K": maintainer.Username}
)
assert response.status_code == int(HTTPStatus.OK)
for keyword in [maintainer.Username, maintainer.Username.upper()]:
with client as request:
# clear fakeredis cache
cache._redis.flushall()
response = request.get("/packages", params={"SeB": "c", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
def test_packages_search_by_co_or_maintainer(
@ -954,27 +957,27 @@ def test_packages_search_by_co_or_maintainer(
PackageComaintainer, PackageBase=package.PackageBase, User=user, Priority=1
)
with client as request:
response = request.get("/packages", params={"SeB": "M", "K": user.Username})
assert response.status_code == int(HTTPStatus.OK)
for keyword in [user.Username, user.Username.upper()]:
with client as request:
response = request.get("/packages", params={"SeB": "M", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
def test_packages_search_by_submitter(
client: TestClient, maintainer: User, package: Package
):
with client as request:
response = request.get(
"/packages", params={"SeB": "s", "K": maintainer.Username}
)
assert response.status_code == int(HTTPStatus.OK)
for keyword in [maintainer.Username, maintainer.Username.upper()]:
with client as request:
response = request.get("/packages", params={"SeB": "s", "K": keyword})
assert response.status_code == int(HTTPStatus.OK)
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
root = parse_root(response.text)
rows = root.xpath('//table[@class="results"]/tbody/tr')
assert len(rows) == 1
def test_packages_sort_by_name(client: TestClient, packages: list[Package]):

View file

@ -153,7 +153,7 @@ def test_pkg_required(package: Package):
# We want to make sure "Package" data is included
# to avoid lazy-loading the information for each dependency
qry = util.pkg_required("test", list())
assert "Packages_ID" in str(qry)
assert "packages_id" in str(qry).lower()
# We should have 1 record
assert qry.count() == 1

View file

@ -430,7 +430,7 @@ def test_pkgbase_comments(
# create notification
with db.begin():
user.CommentNotify = 1
user.CommentNotify = True
db.create(PackageNotification, PackageBase=package.PackageBase, User=user)
# post a comment

View file

@ -149,15 +149,15 @@ def assert_multiple_keys(pks):
def test_hash_query():
# No conditions
query = db.query(User)
assert util.hash_query(query) == "75e76026b7d576536e745ec22892cf8f5d7b5d62"
assert util.hash_query(query) == "ebbf077df70d97a1584f91d0dd6ec61e43aa101f"
# With where clause
query = db.query(User).filter(User.Username == "bla")
assert util.hash_query(query) == "4dca710f33b1344c27ec6a3c266970f4fa6a8a00"
assert util.hash_query(query) == "b51f2bfda67051f381a5c05b2946a1aa4d91e56d"
# With where clause and sorting
query = db.query(User).filter(User.Username == "bla").order_by(User.Username)
assert util.hash_query(query) == "ee2c7846fede430776e140f8dfe1d83cd21d2eed"
assert util.hash_query(query) == "8d458bfe1edfe8f78929fab590612e9e5d9db3a5"
# With where clause, sorting and specific columns
query = (
@ -166,4 +166,4 @@ def test_hash_query():
.order_by(User.Username)
.with_entities(User.Username)
)
assert util.hash_query(query) == "c1db751be61443d266cf643005eee7a884dac103"
assert util.hash_query(query) == "006811a386789f25d40a37496f6ac6651413c245"