diff --git a/src/seedpass/api.py b/src/seedpass/api.py index b10e2e1..18027da 100644 --- a/src/seedpass/api.py +++ b/src/seedpass/api.py @@ -14,6 +14,7 @@ import jwt import logging from fastapi import FastAPI, Header, HTTPException, Request, Response +from fastapi.concurrency import run_in_threadpool import asyncio import sys from fastapi.middleware.cors import CORSMiddleware @@ -132,12 +133,12 @@ def _validate_encryption_path(request: Request, path: Path) -> Path: @app.get("/api/v1/entry") -def search_entry( +async def search_entry( request: Request, query: str, authorization: str | None = Header(None) ) -> List[Any]: _check_token(request, authorization) pm = _get_pm(request) - results = pm.entry_manager.search_entries(query) + results = await run_in_threadpool(pm.entry_manager.search_entries, query) return [ { "id": idx, @@ -152,7 +153,7 @@ def search_entry( @app.get("/api/v1/entry/{entry_id}") -def get_entry( +async def get_entry( request: Request, entry_id: int, authorization: str | None = Header(None), @@ -161,14 +162,14 @@ def get_entry( _check_token(request, authorization) _require_password(request, password) pm = _get_pm(request) - entry = pm.entry_manager.retrieve_entry(entry_id) + entry = await run_in_threadpool(pm.entry_manager.retrieve_entry, entry_id) if entry is None: raise HTTPException(status_code=404, detail="Not found") return entry @app.post("/api/v1/entry") -def create_entry( +async def create_entry( request: Request, entry: dict, authorization: str | None = Header(None), @@ -197,7 +198,8 @@ def create_entry( ] kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None} - index = pm.entry_manager.add_entry( + index = await run_in_threadpool( + pm.entry_manager.add_entry, entry.get("label"), int(entry.get("length", 12)), entry.get("username"), @@ -207,9 +209,10 @@ def create_entry( return {"id": index} if etype == "totp": - index = pm.entry_manager.get_next_index() + index = await run_in_threadpool(pm.entry_manager.get_next_index) - uri = pm.entry_manager.add_totp( + uri = await run_in_threadpool( + pm.entry_manager.add_totp, entry.get("label"), pm.parent_seed, secret=entry.get("secret"), @@ -222,7 +225,8 @@ def create_entry( return {"id": index, "uri": uri} if etype == "ssh": - index = pm.entry_manager.add_ssh_key( + index = await run_in_threadpool( + pm.entry_manager.add_ssh_key, entry.get("label"), pm.parent_seed, index=entry.get("index"), @@ -232,7 +236,8 @@ def create_entry( return {"id": index} if etype == "pgp": - index = pm.entry_manager.add_pgp_key( + index = await run_in_threadpool( + pm.entry_manager.add_pgp_key, entry.get("label"), pm.parent_seed, index=entry.get("index"), @@ -244,7 +249,8 @@ def create_entry( return {"id": index} if etype == "nostr": - index = pm.entry_manager.add_nostr_key( + index = await run_in_threadpool( + pm.entry_manager.add_nostr_key, entry.get("label"), pm.parent_seed, index=entry.get("index"), @@ -254,7 +260,8 @@ def create_entry( return {"id": index} if etype == "key_value": - index = pm.entry_manager.add_key_value( + index = await run_in_threadpool( + pm.entry_manager.add_key_value, entry.get("label"), entry.get("key"), entry.get("value"), @@ -268,7 +275,8 @@ def create_entry( if etype == "seed" else pm.entry_manager.add_managed_account ) - index = func( + index = await run_in_threadpool( + func, entry.get("label"), pm.parent_seed, index=entry.get("index"), diff --git a/src/tests/test_api.py b/src/tests/test_api.py index 9abf7f9..5608361 100644 --- a/src/tests/test_api.py +++ b/src/tests/test_api.py @@ -3,7 +3,7 @@ from pathlib import Path import sys import pytest -from fastapi.testclient import TestClient +from httpx import ASGITransport, AsyncClient import hashlib sys.path.append(str(Path(__file__).resolve().parents[1])) @@ -13,7 +13,7 @@ from seedpass.core.entry_types import EntryType @pytest.fixture -def client(monkeypatch): +async def client(monkeypatch): dummy = SimpleNamespace( entry_manager=SimpleNamespace( search_entries=lambda q: [ @@ -45,27 +45,31 @@ def client(monkeypatch): monkeypatch.setattr(api, "PasswordManager", lambda: dummy) monkeypatch.setenv("SEEDPASS_CORS_ORIGINS", "http://example.com") token = api.start_server() - client = TestClient(api.app) - return client, token + transport = ASGITransport(app=api.app) + async with AsyncClient(transport=transport, base_url="http://test") as ac: + yield ac, token -def test_token_hashed(client): +@pytest.mark.anyio +async def test_token_hashed(client): _, token = client assert api.app.state.token_hash != token assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest() -def test_cors_and_auth(client): +@pytest.mark.anyio +async def test_cors_and_auth(client): cl, token = client headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} - res = cl.get("/api/v1/entry", params={"query": "s"}, headers=headers) + res = await cl.get("/api/v1/entry", params={"query": "s"}, headers=headers) assert res.status_code == 200 assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_invalid_token(client): +@pytest.mark.anyio +async def test_invalid_token(client): cl, _token = client - res = cl.get( + res = await cl.get( "/api/v1/entry", params={"query": "s"}, headers={"Authorization": "Bearer bad"}, @@ -73,60 +77,65 @@ def test_invalid_token(client): assert res.status_code == 401 -def test_get_entry_by_id(client): +@pytest.mark.anyio +async def test_get_entry_by_id(client): cl, token = client headers = { "Authorization": f"Bearer {token}", "Origin": "http://example.com", "X-SeedPass-Password": "pw", } - res = cl.get("/api/v1/entry/1", headers=headers) + res = await cl.get("/api/v1/entry/1", headers=headers) assert res.status_code == 200 assert res.json() == {"label": "Site"} assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_get_config_value(client): +@pytest.mark.anyio +async def test_get_config_value(client): cl, token = client headers = { "Authorization": f"Bearer {token}", "Origin": "http://example.com", } - res = cl.get("/api/v1/config/k", headers=headers) + res = await cl.get("/api/v1/config/k", headers=headers) assert res.status_code == 200 assert res.json() == {"key": "k", "value": "v"} assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_list_fingerprint(client): +@pytest.mark.anyio +async def test_list_fingerprint(client): cl, token = client headers = { "Authorization": f"Bearer {token}", "Origin": "http://example.com", } - res = cl.get("/api/v1/fingerprint", headers=headers) + res = await cl.get("/api/v1/fingerprint", headers=headers) assert res.status_code == 200 assert res.json() == ["fp"] assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_get_nostr_pubkey(client): +@pytest.mark.anyio +async def test_get_nostr_pubkey(client): cl, token = client headers = { "Authorization": f"Bearer {token}", "Origin": "http://example.com", } - res = cl.get("/api/v1/nostr/pubkey", headers=headers) + res = await cl.get("/api/v1/nostr/pubkey", headers=headers) assert res.status_code == 200 assert res.json() == {"npub": "np"} assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_create_modify_archive_entry(client): +@pytest.mark.anyio +async def test_create_modify_archive_entry(client): cl, token = client headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} - res = cl.post( + res = await cl.post( "/api/v1/entry", json={"label": "test", "length": 12}, headers=headers, @@ -134,7 +143,7 @@ def test_create_modify_archive_entry(client): assert res.status_code == 200 assert res.json() == {"id": 1} - res = cl.put( + res = await cl.put( "/api/v1/entry/1", json={"username": "bob"}, headers=headers, @@ -142,16 +151,17 @@ def test_create_modify_archive_entry(client): assert res.status_code == 200 assert res.json() == {"status": "ok"} - res = cl.post("/api/v1/entry/1/archive", headers=headers) + res = await cl.post("/api/v1/entry/1/archive", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "archived"} - res = cl.post("/api/v1/entry/1/unarchive", headers=headers) + res = await cl.post("/api/v1/entry/1/unarchive", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "active"} -def test_update_config(client): +@pytest.mark.anyio +async def test_update_config(client): cl, token = client called = {} @@ -160,7 +170,7 @@ def test_update_config(client): api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} - res = cl.put( + res = await cl.put( "/api/v1/config/inactivity_timeout", json={"value": 42}, headers=headers, @@ -171,14 +181,15 @@ def test_update_config(client): assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_update_config_quick_unlock(client): +@pytest.mark.anyio +async def test_update_config_quick_unlock(client): cl, token = client called = {} api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault( "val", v ) headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} - res = cl.put( + res = await cl.put( "/api/v1/config/quick_unlock", json={"value": True}, headers=headers, @@ -188,12 +199,13 @@ def test_update_config_quick_unlock(client): assert called.get("val") is True -def test_change_password_route(client): +@pytest.mark.anyio +async def test_change_password_route(client): cl, token = client called = {} api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n)) headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} - res = cl.post( + res = await cl.post( "/api/v1/change-password", headers=headers, json={"old": "old", "new": "new"}, @@ -204,10 +216,11 @@ def test_change_password_route(client): assert res.headers.get("access-control-allow-origin") == "http://example.com" -def test_update_config_unknown_key(client): +@pytest.mark.anyio +async def test_update_config_unknown_key(client): cl, token = client headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} - res = cl.put( + res = await cl.put( "/api/v1/config/bogus", json={"value": 1}, headers=headers, @@ -215,7 +228,8 @@ def test_update_config_unknown_key(client): assert res.status_code == 400 -def test_shutdown(client, monkeypatch): +@pytest.mark.anyio +async def test_shutdown(client, monkeypatch): cl, token = client calls = {} @@ -231,7 +245,7 @@ def test_shutdown(client, monkeypatch): "Authorization": f"Bearer {token}", "Origin": "http://example.com", } - res = cl.post("/api/v1/shutdown", headers=headers) + res = await cl.post("/api/v1/shutdown", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "shutting down"} assert calls["func"] is sys.exit @@ -239,6 +253,7 @@ def test_shutdown(client, monkeypatch): assert res.headers.get("access-control-allow-origin") == "http://example.com" +@pytest.mark.anyio @pytest.mark.parametrize( "method,path", [ @@ -256,11 +271,11 @@ def test_shutdown(client, monkeypatch): ("post", "/api/v1/vault/lock"), ], ) -def test_invalid_token_other_endpoints(client, method, path): +async def test_invalid_token_other_endpoints(client, method, path): cl, _token = client req = getattr(cl, method) kwargs = {"headers": {"Authorization": "Bearer bad"}} if method in {"post", "put"}: kwargs["json"] = {} - res = req(path, **kwargs) + res = await req(path, **kwargs) assert res.status_code == 401 diff --git a/src/tests/test_api_new_endpoints.py b/src/tests/test_api_new_endpoints.py index df5f29f..c3e2de1 100644 --- a/src/tests/test_api_new_endpoints.py +++ b/src/tests/test_api_new_endpoints.py @@ -13,7 +13,8 @@ from seedpass.core.encryption import EncryptionManager from nostr.client import NostrClient, DEFAULT_RELAYS -def test_create_and_modify_totp_entry(client): +@pytest.mark.anyio +async def test_create_and_modify_totp_entry(client): cl, token = client calls = {} @@ -30,7 +31,7 @@ def test_create_and_modify_totp_entry(client): api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/entry", json={ "type": "totp", @@ -54,7 +55,7 @@ def test_create_and_modify_totp_entry(client): "archived": False, } - res = cl.put( + res = await cl.put( "/api/v1/entry/5", json={"period": 90, "digits": 6}, headers=headers, @@ -65,7 +66,8 @@ def test_create_and_modify_totp_entry(client): assert calls["modify"][1]["digits"] == 6 -def test_create_and_modify_ssh_entry(client): +@pytest.mark.anyio +async def test_create_and_modify_ssh_entry(client): cl, token = client calls = {} @@ -81,7 +83,7 @@ def test_create_and_modify_ssh_entry(client): api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/entry", json={"type": "ssh", "label": "S", "index": 2, "notes": "n"}, headers=headers, @@ -90,7 +92,7 @@ def test_create_and_modify_ssh_entry(client): assert res.json() == {"id": 2} assert calls["create"] == {"index": 2, "notes": "n", "archived": False} - res = cl.put( + res = await cl.put( "/api/v1/entry/2", json={"notes": "x"}, headers=headers, @@ -100,7 +102,8 @@ def test_create_and_modify_ssh_entry(client): assert calls["modify"][1]["notes"] == "x" -def test_update_entry_error(client): +@pytest.mark.anyio +async def test_update_entry_error(client): cl, token = client def modify(*a, **k): @@ -108,12 +111,13 @@ def test_update_entry_error(client): api.app.state.pm.entry_manager.modify_entry = modify headers = {"Authorization": f"Bearer {token}"} - res = cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers) + res = await cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers) assert res.status_code == 400 assert res.json() == {"detail": "nope"} -def test_update_config_secret_mode(client): +@pytest.mark.anyio +async def test_update_config_secret_mode(client): cl, token = client called = {} @@ -122,7 +126,7 @@ def test_update_config_secret_mode(client): api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret headers = {"Authorization": f"Bearer {token}"} - res = cl.put( + res = await cl.put( "/api/v1/config/secret_mode_enabled", json={"value": True}, headers=headers, @@ -132,17 +136,19 @@ def test_update_config_secret_mode(client): assert called["val"] is True -def test_totp_export_endpoint(client): +@pytest.mark.anyio +async def test_totp_export_endpoint(client): cl, token = client api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]} api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"} - res = cl.get("/api/v1/totp/export", headers=headers) + res = await cl.get("/api/v1/totp/export", headers=headers) assert res.status_code == 200 assert res.json() == {"entries": ["x"]} -def test_totp_codes_endpoint(client): +@pytest.mark.anyio +async def test_totp_codes_endpoint(client): cl, token = client api.app.state.pm.entry_manager.list_entries = lambda **kw: [ (0, "Email", None, None, False) @@ -151,7 +157,7 @@ def test_totp_codes_endpoint(client): api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30 api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"} - res = cl.get("/api/v1/totp", headers=headers) + res = await cl.get("/api/v1/totp", headers=headers) assert res.status_code == 200 assert res.json() == { "codes": [ @@ -160,13 +166,17 @@ def test_totp_codes_endpoint(client): } -def test_parent_seed_endpoint_removed(client): +@pytest.mark.anyio +async def test_parent_seed_endpoint_removed(client): cl, token = client - res = cl.get("/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"}) + res = await cl.get( + "/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"} + ) assert res.status_code == 404 -def test_fingerprint_endpoints(client): +@pytest.mark.anyio +async def test_fingerprint_endpoints(client): cl, token = client calls = {} @@ -178,17 +188,17 @@ def test_fingerprint_endpoints(client): headers = {"Authorization": f"Bearer {token}"} - res = cl.post("/api/v1/fingerprint", headers=headers) + res = await cl.post("/api/v1/fingerprint", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "ok"} assert calls.get("add") is True - res = cl.delete("/api/v1/fingerprint/abc", headers=headers) + res = await cl.delete("/api/v1/fingerprint/abc", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "deleted"} assert calls.get("remove") == "abc" - res = cl.post( + res = await cl.post( "/api/v1/fingerprint/select", json={"fingerprint": "xyz"}, headers=headers, @@ -198,7 +208,8 @@ def test_fingerprint_endpoints(client): assert calls.get("select") == "xyz" -def test_checksum_endpoints(client): +@pytest.mark.anyio +async def test_checksum_endpoints(client): cl, token = client calls = {} @@ -209,18 +220,19 @@ def test_checksum_endpoints(client): headers = {"Authorization": f"Bearer {token}"} - res = cl.post("/api/v1/checksum/verify", headers=headers) + res = await cl.post("/api/v1/checksum/verify", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "ok"} assert calls.get("verify") is True - res = cl.post("/api/v1/checksum/update", headers=headers) + res = await cl.post("/api/v1/checksum/update", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "ok"} assert calls.get("update") is True -def test_vault_import_via_path(client, tmp_path): +@pytest.mark.anyio +async def test_vault_import_via_path(client, tmp_path): cl, token = client called = {} @@ -236,7 +248,7 @@ def test_vault_import_via_path(client, tmp_path): file_path.write_text("{}") headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/vault/import", json={"path": str(file_path)}, headers=headers, @@ -247,7 +259,8 @@ def test_vault_import_via_path(client, tmp_path): assert called.get("sync") is True -def test_vault_import_via_upload(client, tmp_path): +@pytest.mark.anyio +async def test_vault_import_via_upload(client, tmp_path): cl, token = client called = {} @@ -261,7 +274,7 @@ def test_vault_import_via_upload(client, tmp_path): headers = {"Authorization": f"Bearer {token}"} with open(file_path, "rb") as fh: - res = cl.post( + res = await cl.post( "/api/v1/vault/import", files={"file": ("c.json", fh.read())}, headers=headers, @@ -272,7 +285,8 @@ def test_vault_import_via_upload(client, tmp_path): assert called.get("sync") is True -def test_vault_import_invalid_extension(client): +@pytest.mark.anyio +async def test_vault_import_invalid_extension(client): cl, token = client api.app.state.pm.handle_import_database = lambda path: None api.app.state.pm.sync_vault = lambda: None @@ -281,7 +295,7 @@ def test_vault_import_invalid_extension(client): ) headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/vault/import", json={"path": "bad.txt"}, headers=headers, @@ -289,7 +303,8 @@ def test_vault_import_invalid_extension(client): assert res.status_code == 400 -def test_vault_import_path_traversal_blocked(client, tmp_path): +@pytest.mark.anyio +async def test_vault_import_path_traversal_blocked(client, tmp_path): cl, token = client key = base64.urlsafe_b64encode(os.urandom(32)) api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path) @@ -297,7 +312,7 @@ def test_vault_import_path_traversal_blocked(client, tmp_path): api.app.state.pm.sync_vault = lambda: None headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/vault/import", json={"path": "../evil.json.enc"}, headers=headers, @@ -305,7 +320,8 @@ def test_vault_import_path_traversal_blocked(client, tmp_path): assert res.status_code == 400 -def test_vault_lock_endpoint(client): +@pytest.mark.anyio +async def test_vault_lock_endpoint(client): cl, token = client called = {} @@ -317,7 +333,7 @@ def test_vault_lock_endpoint(client): api.app.state.pm.locked = False headers = {"Authorization": f"Bearer {token}"} - res = cl.post("/api/v1/vault/lock", headers=headers) + res = await cl.post("/api/v1/vault/lock", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "locked"} assert called.get("locked") is True @@ -329,7 +345,8 @@ def test_vault_lock_endpoint(client): assert api.app.state.pm.locked is False -def test_secret_mode_endpoint(client): +@pytest.mark.anyio +async def test_secret_mode_endpoint(client): cl, token = client called = {} @@ -343,7 +360,7 @@ def test_secret_mode_endpoint(client): api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/secret-mode", json={"enabled": True, "delay": 12}, headers=headers, @@ -354,7 +371,8 @@ def test_secret_mode_endpoint(client): assert called["delay"] == 12 -def test_vault_export_endpoint(client, tmp_path): +@pytest.mark.anyio +async def test_vault_export_endpoint(client, tmp_path): cl, token = client out = tmp_path / "out.json" out.write_text("data") @@ -365,15 +383,18 @@ def test_vault_export_endpoint(client, tmp_path): "Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw", } - res = cl.post("/api/v1/vault/export", headers=headers) + res = await cl.post("/api/v1/vault/export", headers=headers) assert res.status_code == 200 assert res.content == b"data" - res = cl.post("/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"}) + res = await cl.post( + "/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"} + ) assert res.status_code == 401 -def test_backup_parent_seed_endpoint(client, tmp_path): +@pytest.mark.anyio +async def test_backup_parent_seed_endpoint(client, tmp_path): cl, token = client api.app.state.pm.parent_seed = "seed" called = {} @@ -386,7 +407,7 @@ def test_backup_parent_seed_endpoint(client, tmp_path): "Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw", } - res = cl.post( + res = await cl.post( "/api/v1/vault/backup-parent-seed", json={"path": str(path), "confirm": True}, headers=headers, @@ -395,7 +416,7 @@ def test_backup_parent_seed_endpoint(client, tmp_path): assert res.json() == {"status": "saved", "path": str(path)} assert called["path"] == path - res = cl.post( + res = await cl.post( "/api/v1/vault/backup-parent-seed", json={"path": str(path)}, headers=headers, @@ -403,7 +424,8 @@ def test_backup_parent_seed_endpoint(client, tmp_path): assert res.status_code == 400 -def test_backup_parent_seed_path_traversal_blocked(client, tmp_path): +@pytest.mark.anyio +async def test_backup_parent_seed_path_traversal_blocked(client, tmp_path): cl, token = client api.app.state.pm.parent_seed = "seed" key = base64.urlsafe_b64encode(os.urandom(32)) @@ -412,7 +434,7 @@ def test_backup_parent_seed_path_traversal_blocked(client, tmp_path): "Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw", } - res = cl.post( + res = await cl.post( "/api/v1/vault/backup-parent-seed", json={"path": "../evil.enc", "confirm": True}, headers=headers, @@ -420,7 +442,8 @@ def test_backup_parent_seed_path_traversal_blocked(client, tmp_path): assert res.status_code == 400 -def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): +@pytest.mark.anyio +async def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): cl, token = client nostr_client, _ = dummy_nostr_client relays = ["wss://a", "wss://b"] @@ -448,28 +471,29 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): headers = {"Authorization": f"Bearer {token}"} - res = cl.get("/api/v1/relays", headers=headers) + res = await cl.get("/api/v1/relays", headers=headers) assert res.status_code == 200 assert res.json() == {"relays": relays} - res = cl.post("/api/v1/relays", json={"url": "wss://c"}, headers=headers) + res = await cl.post("/api/v1/relays", json={"url": "wss://c"}, headers=headers) assert res.status_code == 200 assert called["set"] == ["wss://a", "wss://b", "wss://c"] api.app.state.pm.config_manager.load_config = lambda require_pin=False: { "relays": ["wss://a", "wss://b", "wss://c"] } - res = cl.delete("/api/v1/relays/2", headers=headers) + res = await cl.delete("/api/v1/relays/2", headers=headers) assert res.status_code == 200 assert called["set"] == ["wss://a", "wss://c"] - res = cl.post("/api/v1/relays/reset", headers=headers) + res = await cl.post("/api/v1/relays/reset", headers=headers) assert res.status_code == 200 assert called.get("init") is True assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS) -def test_generate_password_no_special_chars(client): +@pytest.mark.anyio +async def test_generate_password_no_special_chars(client): cl, token = client class DummyEnc: @@ -486,7 +510,7 @@ def test_generate_password_no_special_chars(client): api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} - res = cl.post( + res = await cl.post( "/api/v1/password", json={"length": 16, "include_special_chars": False}, headers=headers, @@ -496,7 +520,8 @@ def test_generate_password_no_special_chars(client): assert not any(c in string.punctuation for c in pw) -def test_generate_password_allowed_chars(client): +@pytest.mark.anyio +async def test_generate_password_allowed_chars(client): cl, token = client class DummyEnc: @@ -514,7 +539,7 @@ def test_generate_password_allowed_chars(client): headers = {"Authorization": f"Bearer {token}"} allowed = "@$" - res = cl.post( + res = await cl.post( "/api/v1/password", json={"length": 16, "allowed_special_chars": allowed}, headers=headers, diff --git a/src/tests/test_api_notifications.py b/src/tests/test_api_notifications.py index aefbd7d..57a02a5 100644 --- a/src/tests/test_api_notifications.py +++ b/src/tests/test_api_notifications.py @@ -1,15 +1,19 @@ from test_api import client from types import SimpleNamespace import queue +import pytest import seedpass.api as api -def test_notifications_endpoint(client): +@pytest.mark.anyio +async def test_notifications_endpoint(client): cl, token = client api.app.state.pm.notifications = queue.Queue() api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO")) api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING")) - res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) + res = await cl.get( + "/api/v1/notifications", headers={"Authorization": f"Bearer {token}"} + ) assert res.status_code == 200 assert res.json() == [ {"level": "INFO", "message": "m1"}, @@ -18,19 +22,25 @@ def test_notifications_endpoint(client): assert api.app.state.pm.notifications.empty() -def test_notifications_endpoint_clears_queue(client): +@pytest.mark.anyio +async def test_notifications_endpoint_clears_queue(client): cl, token = client api.app.state.pm.notifications = queue.Queue() api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO")) - res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) + res = await cl.get( + "/api/v1/notifications", headers={"Authorization": f"Bearer {token}"} + ) assert res.status_code == 200 assert res.json() == [{"level": "INFO", "message": "hi"}] assert api.app.state.pm.notifications.empty() - res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) + res = await cl.get( + "/api/v1/notifications", headers={"Authorization": f"Bearer {token}"} + ) assert res.json() == [] -def test_notifications_endpoint_does_not_clear_current(client): +@pytest.mark.anyio +async def test_notifications_endpoint_does_not_clear_current(client): cl, token = client api.app.state.pm.notifications = queue.Queue() msg = SimpleNamespace(message="keep", level="INFO") @@ -40,7 +50,9 @@ def test_notifications_endpoint_does_not_clear_current(client): lambda: api.app.state.pm._current_notification ) - res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) + res = await cl.get( + "/api/v1/notifications", headers={"Authorization": f"Bearer {token}"} + ) assert res.status_code == 200 assert res.json() == [{"level": "INFO", "message": "keep"}] assert api.app.state.pm.notifications.empty() diff --git a/src/tests/test_api_profile_stats.py b/src/tests/test_api_profile_stats.py index a3c62dc..2e172bd 100644 --- a/src/tests/test_api_profile_stats.py +++ b/src/tests/test_api_profile_stats.py @@ -1,13 +1,14 @@ from test_api import client +import pytest -def test_profile_stats_endpoint(client): +@pytest.mark.anyio +async def test_profile_stats_endpoint(client): cl, token = client stats = {"total_entries": 1} - # monkeypatch set _pm.get_profile_stats after client fixture started import seedpass.api as api api.app.state.pm.get_profile_stats = lambda: stats - res = cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"}) + res = await cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"}) assert res.status_code == 200 assert res.json() == stats diff --git a/src/tests/test_api_rate_limit.py b/src/tests/test_api_rate_limit.py index ca42732..8f41d9e 100644 --- a/src/tests/test_api_rate_limit.py +++ b/src/tests/test_api_rate_limit.py @@ -1,15 +1,17 @@ import importlib from pathlib import Path from types import SimpleNamespace - -from fastapi.testclient import TestClient +import importlib +import pytest +from httpx import ASGITransport, AsyncClient import sys sys.path.append(str(Path(__file__).resolve().parents[1])) -def test_rate_limit_exceeded(monkeypatch): +@pytest.mark.anyio +async def test_rate_limit_exceeded(monkeypatch): monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2") monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60") import seedpass.api as api @@ -31,12 +33,15 @@ def test_rate_limit_exceeded(monkeypatch): ) monkeypatch.setattr(api, "PasswordManager", lambda: dummy) token = api.start_server() - client = TestClient(api.app) - headers = {"Authorization": f"Bearer {token}"} + transport = ASGITransport(app=api.app) + async with AsyncClient(transport=transport, base_url="http://test") as client: + headers = {"Authorization": f"Bearer {token}"} - for _ in range(2): - res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers) - assert res.status_code == 200 + for _ in range(2): + res = await client.get( + "/api/v1/entry", params={"query": "s"}, headers=headers + ) + assert res.status_code == 200 - res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers) - assert res.status_code == 429 + res = await client.get("/api/v1/entry", params={"query": "s"}, headers=headers) + assert res.status_code == 429