Merge pull request #771 from PR0M3TH3AN/codex/convert-api-functions-to-async

Refactor API endpoints to async
This commit is contained in:
thePR0M3TH3AN
2025-08-05 20:24:01 -04:00
committed by GitHub
7 changed files with 193 additions and 119 deletions

View File

@@ -14,6 +14,7 @@ import jwt
import logging import logging
from fastapi import FastAPI, Header, HTTPException, Request, Response from fastapi import FastAPI, Header, HTTPException, Request, Response
from fastapi.concurrency import run_in_threadpool
import asyncio import asyncio
import sys import sys
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
@@ -132,12 +133,12 @@ def _validate_encryption_path(request: Request, path: Path) -> Path:
@app.get("/api/v1/entry") @app.get("/api/v1/entry")
def search_entry( async def search_entry(
request: Request, query: str, authorization: str | None = Header(None) request: Request, query: str, authorization: str | None = Header(None)
) -> List[Any]: ) -> List[Any]:
_check_token(request, authorization) _check_token(request, authorization)
pm = _get_pm(request) pm = _get_pm(request)
results = pm.entry_manager.search_entries(query) results = await run_in_threadpool(pm.entry_manager.search_entries, query)
return [ return [
{ {
"id": idx, "id": idx,
@@ -152,7 +153,7 @@ def search_entry(
@app.get("/api/v1/entry/{entry_id}") @app.get("/api/v1/entry/{entry_id}")
def get_entry( async def get_entry(
request: Request, request: Request,
entry_id: int, entry_id: int,
authorization: str | None = Header(None), authorization: str | None = Header(None),
@@ -161,14 +162,14 @@ def get_entry(
_check_token(request, authorization) _check_token(request, authorization)
_require_password(request, password) _require_password(request, password)
pm = _get_pm(request) pm = _get_pm(request)
entry = pm.entry_manager.retrieve_entry(entry_id) entry = await run_in_threadpool(pm.entry_manager.retrieve_entry, entry_id)
if entry is None: if entry is None:
raise HTTPException(status_code=404, detail="Not found") raise HTTPException(status_code=404, detail="Not found")
return entry return entry
@app.post("/api/v1/entry") @app.post("/api/v1/entry")
def create_entry( async def create_entry(
request: Request, request: Request,
entry: dict, entry: dict,
authorization: str | None = Header(None), authorization: str | None = Header(None),
@@ -197,7 +198,8 @@ def create_entry(
] ]
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None} kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
index = pm.entry_manager.add_entry( index = await run_in_threadpool(
pm.entry_manager.add_entry,
entry.get("label"), entry.get("label"),
int(entry.get("length", 12)), int(entry.get("length", 12)),
entry.get("username"), entry.get("username"),
@@ -207,9 +209,10 @@ def create_entry(
return {"id": index} return {"id": index}
if etype == "totp": if etype == "totp":
index = pm.entry_manager.get_next_index() index = await run_in_threadpool(pm.entry_manager.get_next_index)
uri = pm.entry_manager.add_totp( uri = await run_in_threadpool(
pm.entry_manager.add_totp,
entry.get("label"), entry.get("label"),
pm.parent_seed, pm.parent_seed,
secret=entry.get("secret"), secret=entry.get("secret"),
@@ -222,7 +225,8 @@ def create_entry(
return {"id": index, "uri": uri} return {"id": index, "uri": uri}
if etype == "ssh": if etype == "ssh":
index = pm.entry_manager.add_ssh_key( index = await run_in_threadpool(
pm.entry_manager.add_ssh_key,
entry.get("label"), entry.get("label"),
pm.parent_seed, pm.parent_seed,
index=entry.get("index"), index=entry.get("index"),
@@ -232,7 +236,8 @@ def create_entry(
return {"id": index} return {"id": index}
if etype == "pgp": if etype == "pgp":
index = pm.entry_manager.add_pgp_key( index = await run_in_threadpool(
pm.entry_manager.add_pgp_key,
entry.get("label"), entry.get("label"),
pm.parent_seed, pm.parent_seed,
index=entry.get("index"), index=entry.get("index"),
@@ -244,7 +249,8 @@ def create_entry(
return {"id": index} return {"id": index}
if etype == "nostr": if etype == "nostr":
index = pm.entry_manager.add_nostr_key( index = await run_in_threadpool(
pm.entry_manager.add_nostr_key,
entry.get("label"), entry.get("label"),
pm.parent_seed, pm.parent_seed,
index=entry.get("index"), index=entry.get("index"),
@@ -254,7 +260,8 @@ def create_entry(
return {"id": index} return {"id": index}
if etype == "key_value": if etype == "key_value":
index = pm.entry_manager.add_key_value( index = await run_in_threadpool(
pm.entry_manager.add_key_value,
entry.get("label"), entry.get("label"),
entry.get("key"), entry.get("key"),
entry.get("value"), entry.get("value"),
@@ -268,7 +275,8 @@ def create_entry(
if etype == "seed" if etype == "seed"
else pm.entry_manager.add_managed_account else pm.entry_manager.add_managed_account
) )
index = func( index = await run_in_threadpool(
func,
entry.get("label"), entry.get("label"),
pm.parent_seed, pm.parent_seed,
index=entry.get("index"), index=entry.get("index"),

View File

@@ -1,7 +1,15 @@
import importlib.util
import logging import logging
import pytest import pytest
@pytest.fixture(
params=["asyncio"] + (["trio"] if importlib.util.find_spec("trio") else [])
)
def anyio_backend(request):
return request.param
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def mute_logging(): def mute_logging():
logging.getLogger().setLevel(logging.WARNING) logging.getLogger().setLevel(logging.WARNING)

View File

@@ -3,7 +3,7 @@ from pathlib import Path
import sys import sys
import pytest import pytest
from fastapi.testclient import TestClient from httpx import ASGITransport, AsyncClient
import hashlib import hashlib
sys.path.append(str(Path(__file__).resolve().parents[1])) sys.path.append(str(Path(__file__).resolve().parents[1]))
@@ -13,7 +13,7 @@ from seedpass.core.entry_types import EntryType
@pytest.fixture @pytest.fixture
def client(monkeypatch): async def client(monkeypatch):
dummy = SimpleNamespace( dummy = SimpleNamespace(
entry_manager=SimpleNamespace( entry_manager=SimpleNamespace(
search_entries=lambda q: [ search_entries=lambda q: [
@@ -45,27 +45,31 @@ def client(monkeypatch):
monkeypatch.setattr(api, "PasswordManager", lambda: dummy) monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
monkeypatch.setenv("SEEDPASS_CORS_ORIGINS", "http://example.com") monkeypatch.setenv("SEEDPASS_CORS_ORIGINS", "http://example.com")
token = api.start_server() token = api.start_server()
client = TestClient(api.app) transport = ASGITransport(app=api.app)
return client, token async with AsyncClient(transport=transport, base_url="http://test") as ac:
yield ac, token
def test_token_hashed(client): @pytest.mark.anyio
async def test_token_hashed(client):
_, token = client _, token = client
assert api.app.state.token_hash != token assert api.app.state.token_hash != token
assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest() assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest()
def test_cors_and_auth(client): @pytest.mark.anyio
async def test_cors_and_auth(client):
cl, token = client cl, token = client
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.get("/api/v1/entry", params={"query": "s"}, headers=headers) res = await cl.get("/api/v1/entry", params={"query": "s"}, headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_invalid_token(client): @pytest.mark.anyio
async def test_invalid_token(client):
cl, _token = client cl, _token = client
res = cl.get( res = await cl.get(
"/api/v1/entry", "/api/v1/entry",
params={"query": "s"}, params={"query": "s"},
headers={"Authorization": "Bearer bad"}, headers={"Authorization": "Bearer bad"},
@@ -73,60 +77,65 @@ def test_invalid_token(client):
assert res.status_code == 401 assert res.status_code == 401
def test_get_entry_by_id(client): @pytest.mark.anyio
async def test_get_entry_by_id(client):
cl, token = client cl, token = client
headers = { headers = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"Origin": "http://example.com", "Origin": "http://example.com",
"X-SeedPass-Password": "pw", "X-SeedPass-Password": "pw",
} }
res = cl.get("/api/v1/entry/1", headers=headers) res = await cl.get("/api/v1/entry/1", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"label": "Site"} assert res.json() == {"label": "Site"}
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_get_config_value(client): @pytest.mark.anyio
async def test_get_config_value(client):
cl, token = client cl, token = client
headers = { headers = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"Origin": "http://example.com", "Origin": "http://example.com",
} }
res = cl.get("/api/v1/config/k", headers=headers) res = await cl.get("/api/v1/config/k", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"key": "k", "value": "v"} assert res.json() == {"key": "k", "value": "v"}
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_list_fingerprint(client): @pytest.mark.anyio
async def test_list_fingerprint(client):
cl, token = client cl, token = client
headers = { headers = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"Origin": "http://example.com", "Origin": "http://example.com",
} }
res = cl.get("/api/v1/fingerprint", headers=headers) res = await cl.get("/api/v1/fingerprint", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == ["fp"] assert res.json() == ["fp"]
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_get_nostr_pubkey(client): @pytest.mark.anyio
async def test_get_nostr_pubkey(client):
cl, token = client cl, token = client
headers = { headers = {
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"Origin": "http://example.com", "Origin": "http://example.com",
} }
res = cl.get("/api/v1/nostr/pubkey", headers=headers) res = await cl.get("/api/v1/nostr/pubkey", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"npub": "np"} assert res.json() == {"npub": "np"}
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_create_modify_archive_entry(client): @pytest.mark.anyio
async def test_create_modify_archive_entry(client):
cl, token = client cl, token = client
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.post( res = await cl.post(
"/api/v1/entry", "/api/v1/entry",
json={"label": "test", "length": 12}, json={"label": "test", "length": 12},
headers=headers, headers=headers,
@@ -134,7 +143,7 @@ def test_create_modify_archive_entry(client):
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"id": 1} assert res.json() == {"id": 1}
res = cl.put( res = await cl.put(
"/api/v1/entry/1", "/api/v1/entry/1",
json={"username": "bob"}, json={"username": "bob"},
headers=headers, headers=headers,
@@ -142,16 +151,17 @@ def test_create_modify_archive_entry(client):
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "ok"} assert res.json() == {"status": "ok"}
res = cl.post("/api/v1/entry/1/archive", headers=headers) res = await cl.post("/api/v1/entry/1/archive", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "archived"} assert res.json() == {"status": "archived"}
res = cl.post("/api/v1/entry/1/unarchive", headers=headers) res = await cl.post("/api/v1/entry/1/unarchive", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "active"} assert res.json() == {"status": "active"}
def test_update_config(client): @pytest.mark.anyio
async def test_update_config(client):
cl, token = client cl, token = client
called = {} called = {}
@@ -160,7 +170,7 @@ def test_update_config(client):
api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.put( res = await cl.put(
"/api/v1/config/inactivity_timeout", "/api/v1/config/inactivity_timeout",
json={"value": 42}, json={"value": 42},
headers=headers, headers=headers,
@@ -171,14 +181,15 @@ def test_update_config(client):
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_update_config_quick_unlock(client): @pytest.mark.anyio
async def test_update_config_quick_unlock(client):
cl, token = client cl, token = client
called = {} called = {}
api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault( api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault(
"val", v "val", v
) )
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.put( res = await cl.put(
"/api/v1/config/quick_unlock", "/api/v1/config/quick_unlock",
json={"value": True}, json={"value": True},
headers=headers, headers=headers,
@@ -188,12 +199,13 @@ def test_update_config_quick_unlock(client):
assert called.get("val") is True assert called.get("val") is True
def test_change_password_route(client): @pytest.mark.anyio
async def test_change_password_route(client):
cl, token = client cl, token = client
called = {} called = {}
api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n)) api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n))
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.post( res = await cl.post(
"/api/v1/change-password", "/api/v1/change-password",
headers=headers, headers=headers,
json={"old": "old", "new": "new"}, json={"old": "old", "new": "new"},
@@ -204,10 +216,11 @@ def test_change_password_route(client):
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
def test_update_config_unknown_key(client): @pytest.mark.anyio
async def test_update_config_unknown_key(client):
cl, token = client cl, token = client
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.put( res = await cl.put(
"/api/v1/config/bogus", "/api/v1/config/bogus",
json={"value": 1}, json={"value": 1},
headers=headers, headers=headers,
@@ -215,7 +228,8 @@ def test_update_config_unknown_key(client):
assert res.status_code == 400 assert res.status_code == 400
def test_shutdown(client, monkeypatch): @pytest.mark.anyio
async def test_shutdown(client, monkeypatch):
cl, token = client cl, token = client
calls = {} calls = {}
@@ -231,7 +245,7 @@ def test_shutdown(client, monkeypatch):
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"Origin": "http://example.com", "Origin": "http://example.com",
} }
res = cl.post("/api/v1/shutdown", headers=headers) res = await cl.post("/api/v1/shutdown", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "shutting down"} assert res.json() == {"status": "shutting down"}
assert calls["func"] is sys.exit assert calls["func"] is sys.exit
@@ -239,6 +253,7 @@ def test_shutdown(client, monkeypatch):
assert res.headers.get("access-control-allow-origin") == "http://example.com" assert res.headers.get("access-control-allow-origin") == "http://example.com"
@pytest.mark.anyio
@pytest.mark.parametrize( @pytest.mark.parametrize(
"method,path", "method,path",
[ [
@@ -256,11 +271,11 @@ def test_shutdown(client, monkeypatch):
("post", "/api/v1/vault/lock"), ("post", "/api/v1/vault/lock"),
], ],
) )
def test_invalid_token_other_endpoints(client, method, path): async def test_invalid_token_other_endpoints(client, method, path):
cl, _token = client cl, _token = client
req = getattr(cl, method) req = getattr(cl, method)
kwargs = {"headers": {"Authorization": "Bearer bad"}} kwargs = {"headers": {"Authorization": "Bearer bad"}}
if method in {"post", "put"}: if method in {"post", "put"}:
kwargs["json"] = {} kwargs["json"] = {}
res = req(path, **kwargs) res = await req(path, **kwargs)
assert res.status_code == 401 assert res.status_code == 401

View File

@@ -13,7 +13,8 @@ from seedpass.core.encryption import EncryptionManager
from nostr.client import NostrClient, DEFAULT_RELAYS from nostr.client import NostrClient, DEFAULT_RELAYS
def test_create_and_modify_totp_entry(client): @pytest.mark.anyio
async def test_create_and_modify_totp_entry(client):
cl, token = client cl, token = client
calls = {} calls = {}
@@ -30,7 +31,7 @@ def test_create_and_modify_totp_entry(client):
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/entry", "/api/v1/entry",
json={ json={
"type": "totp", "type": "totp",
@@ -54,7 +55,7 @@ def test_create_and_modify_totp_entry(client):
"archived": False, "archived": False,
} }
res = cl.put( res = await cl.put(
"/api/v1/entry/5", "/api/v1/entry/5",
json={"period": 90, "digits": 6}, json={"period": 90, "digits": 6},
headers=headers, headers=headers,
@@ -65,7 +66,8 @@ def test_create_and_modify_totp_entry(client):
assert calls["modify"][1]["digits"] == 6 assert calls["modify"][1]["digits"] == 6
def test_create_and_modify_ssh_entry(client): @pytest.mark.anyio
async def test_create_and_modify_ssh_entry(client):
cl, token = client cl, token = client
calls = {} calls = {}
@@ -81,7 +83,7 @@ def test_create_and_modify_ssh_entry(client):
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/entry", "/api/v1/entry",
json={"type": "ssh", "label": "S", "index": 2, "notes": "n"}, json={"type": "ssh", "label": "S", "index": 2, "notes": "n"},
headers=headers, headers=headers,
@@ -90,7 +92,7 @@ def test_create_and_modify_ssh_entry(client):
assert res.json() == {"id": 2} assert res.json() == {"id": 2}
assert calls["create"] == {"index": 2, "notes": "n", "archived": False} assert calls["create"] == {"index": 2, "notes": "n", "archived": False}
res = cl.put( res = await cl.put(
"/api/v1/entry/2", "/api/v1/entry/2",
json={"notes": "x"}, json={"notes": "x"},
headers=headers, headers=headers,
@@ -100,7 +102,8 @@ def test_create_and_modify_ssh_entry(client):
assert calls["modify"][1]["notes"] == "x" assert calls["modify"][1]["notes"] == "x"
def test_update_entry_error(client): @pytest.mark.anyio
async def test_update_entry_error(client):
cl, token = client cl, token = client
def modify(*a, **k): def modify(*a, **k):
@@ -108,12 +111,13 @@ def test_update_entry_error(client):
api.app.state.pm.entry_manager.modify_entry = modify api.app.state.pm.entry_manager.modify_entry = modify
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers) res = await cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers)
assert res.status_code == 400 assert res.status_code == 400
assert res.json() == {"detail": "nope"} assert res.json() == {"detail": "nope"}
def test_update_config_secret_mode(client): @pytest.mark.anyio
async def test_update_config_secret_mode(client):
cl, token = client cl, token = client
called = {} called = {}
@@ -122,7 +126,7 @@ def test_update_config_secret_mode(client):
api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.put( res = await cl.put(
"/api/v1/config/secret_mode_enabled", "/api/v1/config/secret_mode_enabled",
json={"value": True}, json={"value": True},
headers=headers, headers=headers,
@@ -132,17 +136,19 @@ def test_update_config_secret_mode(client):
assert called["val"] is True assert called["val"] is True
def test_totp_export_endpoint(client): @pytest.mark.anyio
async def test_totp_export_endpoint(client):
cl, token = client cl, token = client
api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]} api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]}
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"} headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
res = cl.get("/api/v1/totp/export", headers=headers) res = await cl.get("/api/v1/totp/export", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"entries": ["x"]} assert res.json() == {"entries": ["x"]}
def test_totp_codes_endpoint(client): @pytest.mark.anyio
async def test_totp_codes_endpoint(client):
cl, token = client cl, token = client
api.app.state.pm.entry_manager.list_entries = lambda **kw: [ api.app.state.pm.entry_manager.list_entries = lambda **kw: [
(0, "Email", None, None, False) (0, "Email", None, None, False)
@@ -151,7 +157,7 @@ def test_totp_codes_endpoint(client):
api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30 api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"} headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
res = cl.get("/api/v1/totp", headers=headers) res = await cl.get("/api/v1/totp", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == { assert res.json() == {
"codes": [ "codes": [
@@ -160,13 +166,17 @@ def test_totp_codes_endpoint(client):
} }
def test_parent_seed_endpoint_removed(client): @pytest.mark.anyio
async def test_parent_seed_endpoint_removed(client):
cl, token = client cl, token = client
res = cl.get("/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"}) res = await cl.get(
"/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"}
)
assert res.status_code == 404 assert res.status_code == 404
def test_fingerprint_endpoints(client): @pytest.mark.anyio
async def test_fingerprint_endpoints(client):
cl, token = client cl, token = client
calls = {} calls = {}
@@ -178,17 +188,17 @@ def test_fingerprint_endpoints(client):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post("/api/v1/fingerprint", headers=headers) res = await cl.post("/api/v1/fingerprint", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "ok"} assert res.json() == {"status": "ok"}
assert calls.get("add") is True assert calls.get("add") is True
res = cl.delete("/api/v1/fingerprint/abc", headers=headers) res = await cl.delete("/api/v1/fingerprint/abc", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "deleted"} assert res.json() == {"status": "deleted"}
assert calls.get("remove") == "abc" assert calls.get("remove") == "abc"
res = cl.post( res = await cl.post(
"/api/v1/fingerprint/select", "/api/v1/fingerprint/select",
json={"fingerprint": "xyz"}, json={"fingerprint": "xyz"},
headers=headers, headers=headers,
@@ -198,7 +208,8 @@ def test_fingerprint_endpoints(client):
assert calls.get("select") == "xyz" assert calls.get("select") == "xyz"
def test_checksum_endpoints(client): @pytest.mark.anyio
async def test_checksum_endpoints(client):
cl, token = client cl, token = client
calls = {} calls = {}
@@ -209,18 +220,19 @@ def test_checksum_endpoints(client):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post("/api/v1/checksum/verify", headers=headers) res = await cl.post("/api/v1/checksum/verify", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "ok"} assert res.json() == {"status": "ok"}
assert calls.get("verify") is True assert calls.get("verify") is True
res = cl.post("/api/v1/checksum/update", headers=headers) res = await cl.post("/api/v1/checksum/update", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "ok"} assert res.json() == {"status": "ok"}
assert calls.get("update") is True assert calls.get("update") is True
def test_vault_import_via_path(client, tmp_path): @pytest.mark.anyio
async def test_vault_import_via_path(client, tmp_path):
cl, token = client cl, token = client
called = {} called = {}
@@ -236,7 +248,7 @@ def test_vault_import_via_path(client, tmp_path):
file_path.write_text("{}") file_path.write_text("{}")
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/vault/import", "/api/v1/vault/import",
json={"path": str(file_path)}, json={"path": str(file_path)},
headers=headers, headers=headers,
@@ -247,7 +259,8 @@ def test_vault_import_via_path(client, tmp_path):
assert called.get("sync") is True assert called.get("sync") is True
def test_vault_import_via_upload(client, tmp_path): @pytest.mark.anyio
async def test_vault_import_via_upload(client, tmp_path):
cl, token = client cl, token = client
called = {} called = {}
@@ -261,7 +274,7 @@ def test_vault_import_via_upload(client, tmp_path):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
with open(file_path, "rb") as fh: with open(file_path, "rb") as fh:
res = cl.post( res = await cl.post(
"/api/v1/vault/import", "/api/v1/vault/import",
files={"file": ("c.json", fh.read())}, files={"file": ("c.json", fh.read())},
headers=headers, headers=headers,
@@ -272,7 +285,8 @@ def test_vault_import_via_upload(client, tmp_path):
assert called.get("sync") is True assert called.get("sync") is True
def test_vault_import_invalid_extension(client): @pytest.mark.anyio
async def test_vault_import_invalid_extension(client):
cl, token = client cl, token = client
api.app.state.pm.handle_import_database = lambda path: None api.app.state.pm.handle_import_database = lambda path: None
api.app.state.pm.sync_vault = lambda: None api.app.state.pm.sync_vault = lambda: None
@@ -281,7 +295,7 @@ def test_vault_import_invalid_extension(client):
) )
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/vault/import", "/api/v1/vault/import",
json={"path": "bad.txt"}, json={"path": "bad.txt"},
headers=headers, headers=headers,
@@ -289,7 +303,8 @@ def test_vault_import_invalid_extension(client):
assert res.status_code == 400 assert res.status_code == 400
def test_vault_import_path_traversal_blocked(client, tmp_path): @pytest.mark.anyio
async def test_vault_import_path_traversal_blocked(client, tmp_path):
cl, token = client cl, token = client
key = base64.urlsafe_b64encode(os.urandom(32)) key = base64.urlsafe_b64encode(os.urandom(32))
api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path) api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path)
@@ -297,7 +312,7 @@ def test_vault_import_path_traversal_blocked(client, tmp_path):
api.app.state.pm.sync_vault = lambda: None api.app.state.pm.sync_vault = lambda: None
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/vault/import", "/api/v1/vault/import",
json={"path": "../evil.json.enc"}, json={"path": "../evil.json.enc"},
headers=headers, headers=headers,
@@ -305,7 +320,8 @@ def test_vault_import_path_traversal_blocked(client, tmp_path):
assert res.status_code == 400 assert res.status_code == 400
def test_vault_lock_endpoint(client): @pytest.mark.anyio
async def test_vault_lock_endpoint(client):
cl, token = client cl, token = client
called = {} called = {}
@@ -317,7 +333,7 @@ def test_vault_lock_endpoint(client):
api.app.state.pm.locked = False api.app.state.pm.locked = False
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post("/api/v1/vault/lock", headers=headers) res = await cl.post("/api/v1/vault/lock", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"status": "locked"} assert res.json() == {"status": "locked"}
assert called.get("locked") is True assert called.get("locked") is True
@@ -329,7 +345,8 @@ def test_vault_lock_endpoint(client):
assert api.app.state.pm.locked is False assert api.app.state.pm.locked is False
def test_secret_mode_endpoint(client): @pytest.mark.anyio
async def test_secret_mode_endpoint(client):
cl, token = client cl, token = client
called = {} called = {}
@@ -343,7 +360,7 @@ def test_secret_mode_endpoint(client):
api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/secret-mode", "/api/v1/secret-mode",
json={"enabled": True, "delay": 12}, json={"enabled": True, "delay": 12},
headers=headers, headers=headers,
@@ -354,7 +371,8 @@ def test_secret_mode_endpoint(client):
assert called["delay"] == 12 assert called["delay"] == 12
def test_vault_export_endpoint(client, tmp_path): @pytest.mark.anyio
async def test_vault_export_endpoint(client, tmp_path):
cl, token = client cl, token = client
out = tmp_path / "out.json" out = tmp_path / "out.json"
out.write_text("data") out.write_text("data")
@@ -365,15 +383,18 @@ def test_vault_export_endpoint(client, tmp_path):
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"X-SeedPass-Password": "pw", "X-SeedPass-Password": "pw",
} }
res = cl.post("/api/v1/vault/export", headers=headers) res = await cl.post("/api/v1/vault/export", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.content == b"data" assert res.content == b"data"
res = cl.post("/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"}) res = await cl.post(
"/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"}
)
assert res.status_code == 401 assert res.status_code == 401
def test_backup_parent_seed_endpoint(client, tmp_path): @pytest.mark.anyio
async def test_backup_parent_seed_endpoint(client, tmp_path):
cl, token = client cl, token = client
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
called = {} called = {}
@@ -386,7 +407,7 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"X-SeedPass-Password": "pw", "X-SeedPass-Password": "pw",
} }
res = cl.post( res = await cl.post(
"/api/v1/vault/backup-parent-seed", "/api/v1/vault/backup-parent-seed",
json={"path": str(path), "confirm": True}, json={"path": str(path), "confirm": True},
headers=headers, headers=headers,
@@ -395,7 +416,7 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
assert res.json() == {"status": "saved", "path": str(path)} assert res.json() == {"status": "saved", "path": str(path)}
assert called["path"] == path assert called["path"] == path
res = cl.post( res = await cl.post(
"/api/v1/vault/backup-parent-seed", "/api/v1/vault/backup-parent-seed",
json={"path": str(path)}, json={"path": str(path)},
headers=headers, headers=headers,
@@ -403,7 +424,8 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
assert res.status_code == 400 assert res.status_code == 400
def test_backup_parent_seed_path_traversal_blocked(client, tmp_path): @pytest.mark.anyio
async def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
cl, token = client cl, token = client
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
key = base64.urlsafe_b64encode(os.urandom(32)) key = base64.urlsafe_b64encode(os.urandom(32))
@@ -412,7 +434,7 @@ def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
"Authorization": f"Bearer {token}", "Authorization": f"Bearer {token}",
"X-SeedPass-Password": "pw", "X-SeedPass-Password": "pw",
} }
res = cl.post( res = await cl.post(
"/api/v1/vault/backup-parent-seed", "/api/v1/vault/backup-parent-seed",
json={"path": "../evil.enc", "confirm": True}, json={"path": "../evil.enc", "confirm": True},
headers=headers, headers=headers,
@@ -420,7 +442,8 @@ def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
assert res.status_code == 400 assert res.status_code == 400
def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): @pytest.mark.anyio
async def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
cl, token = client cl, token = client
nostr_client, _ = dummy_nostr_client nostr_client, _ = dummy_nostr_client
relays = ["wss://a", "wss://b"] relays = ["wss://a", "wss://b"]
@@ -448,28 +471,29 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.get("/api/v1/relays", headers=headers) res = await cl.get("/api/v1/relays", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == {"relays": relays} assert res.json() == {"relays": relays}
res = cl.post("/api/v1/relays", json={"url": "wss://c"}, headers=headers) res = await cl.post("/api/v1/relays", json={"url": "wss://c"}, headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert called["set"] == ["wss://a", "wss://b", "wss://c"] assert called["set"] == ["wss://a", "wss://b", "wss://c"]
api.app.state.pm.config_manager.load_config = lambda require_pin=False: { api.app.state.pm.config_manager.load_config = lambda require_pin=False: {
"relays": ["wss://a", "wss://b", "wss://c"] "relays": ["wss://a", "wss://b", "wss://c"]
} }
res = cl.delete("/api/v1/relays/2", headers=headers) res = await cl.delete("/api/v1/relays/2", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert called["set"] == ["wss://a", "wss://c"] assert called["set"] == ["wss://a", "wss://c"]
res = cl.post("/api/v1/relays/reset", headers=headers) res = await cl.post("/api/v1/relays/reset", headers=headers)
assert res.status_code == 200 assert res.status_code == 200
assert called.get("init") is True assert called.get("init") is True
assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS) assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS)
def test_generate_password_no_special_chars(client): @pytest.mark.anyio
async def test_generate_password_no_special_chars(client):
cl, token = client cl, token = client
class DummyEnc: class DummyEnc:
@@ -486,7 +510,7 @@ def test_generate_password_no_special_chars(client):
api.app.state.pm.parent_seed = "seed" api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
res = cl.post( res = await cl.post(
"/api/v1/password", "/api/v1/password",
json={"length": 16, "include_special_chars": False}, json={"length": 16, "include_special_chars": False},
headers=headers, headers=headers,
@@ -496,7 +520,8 @@ def test_generate_password_no_special_chars(client):
assert not any(c in string.punctuation for c in pw) assert not any(c in string.punctuation for c in pw)
def test_generate_password_allowed_chars(client): @pytest.mark.anyio
async def test_generate_password_allowed_chars(client):
cl, token = client cl, token = client
class DummyEnc: class DummyEnc:
@@ -514,7 +539,7 @@ def test_generate_password_allowed_chars(client):
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
allowed = "@$" allowed = "@$"
res = cl.post( res = await cl.post(
"/api/v1/password", "/api/v1/password",
json={"length": 16, "allowed_special_chars": allowed}, json={"length": 16, "allowed_special_chars": allowed},
headers=headers, headers=headers,

View File

@@ -1,15 +1,19 @@
from test_api import client from test_api import client
from types import SimpleNamespace from types import SimpleNamespace
import queue import queue
import pytest
import seedpass.api as api import seedpass.api as api
def test_notifications_endpoint(client): @pytest.mark.anyio
async def test_notifications_endpoint(client):
cl, token = client cl, token = client
api.app.state.pm.notifications = queue.Queue() api.app.state.pm.notifications = queue.Queue()
api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO")) api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO"))
api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING")) api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING"))
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) res = await cl.get(
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == [ assert res.json() == [
{"level": "INFO", "message": "m1"}, {"level": "INFO", "message": "m1"},
@@ -18,19 +22,25 @@ def test_notifications_endpoint(client):
assert api.app.state.pm.notifications.empty() assert api.app.state.pm.notifications.empty()
def test_notifications_endpoint_clears_queue(client): @pytest.mark.anyio
async def test_notifications_endpoint_clears_queue(client):
cl, token = client cl, token = client
api.app.state.pm.notifications = queue.Queue() api.app.state.pm.notifications = queue.Queue()
api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO")) api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO"))
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) res = await cl.get(
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == [{"level": "INFO", "message": "hi"}] assert res.json() == [{"level": "INFO", "message": "hi"}]
assert api.app.state.pm.notifications.empty() assert api.app.state.pm.notifications.empty()
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) res = await cl.get(
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
)
assert res.json() == [] assert res.json() == []
def test_notifications_endpoint_does_not_clear_current(client): @pytest.mark.anyio
async def test_notifications_endpoint_does_not_clear_current(client):
cl, token = client cl, token = client
api.app.state.pm.notifications = queue.Queue() api.app.state.pm.notifications = queue.Queue()
msg = SimpleNamespace(message="keep", level="INFO") msg = SimpleNamespace(message="keep", level="INFO")
@@ -40,7 +50,9 @@ def test_notifications_endpoint_does_not_clear_current(client):
lambda: api.app.state.pm._current_notification lambda: api.app.state.pm._current_notification
) )
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) res = await cl.get(
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
)
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == [{"level": "INFO", "message": "keep"}] assert res.json() == [{"level": "INFO", "message": "keep"}]
assert api.app.state.pm.notifications.empty() assert api.app.state.pm.notifications.empty()

View File

@@ -1,13 +1,14 @@
from test_api import client from test_api import client
import pytest
def test_profile_stats_endpoint(client): @pytest.mark.anyio
async def test_profile_stats_endpoint(client):
cl, token = client cl, token = client
stats = {"total_entries": 1} stats = {"total_entries": 1}
# monkeypatch set _pm.get_profile_stats after client fixture started
import seedpass.api as api import seedpass.api as api
api.app.state.pm.get_profile_stats = lambda: stats api.app.state.pm.get_profile_stats = lambda: stats
res = cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"}) res = await cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 200 assert res.status_code == 200
assert res.json() == stats assert res.json() == stats

View File

@@ -1,15 +1,17 @@
import importlib import importlib
from pathlib import Path from pathlib import Path
from types import SimpleNamespace from types import SimpleNamespace
import importlib
from fastapi.testclient import TestClient import pytest
from httpx import ASGITransport, AsyncClient
import sys import sys
sys.path.append(str(Path(__file__).resolve().parents[1])) sys.path.append(str(Path(__file__).resolve().parents[1]))
def test_rate_limit_exceeded(monkeypatch): @pytest.mark.anyio
async def test_rate_limit_exceeded(monkeypatch):
monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2") monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2")
monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60") monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60")
import seedpass.api as api import seedpass.api as api
@@ -31,12 +33,15 @@ def test_rate_limit_exceeded(monkeypatch):
) )
monkeypatch.setattr(api, "PasswordManager", lambda: dummy) monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
token = api.start_server() token = api.start_server()
client = TestClient(api.app) transport = ASGITransport(app=api.app)
async with AsyncClient(transport=transport, base_url="http://test") as client:
headers = {"Authorization": f"Bearer {token}"} headers = {"Authorization": f"Bearer {token}"}
for _ in range(2): for _ in range(2):
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers) res = await client.get(
"/api/v1/entry", params={"query": "s"}, headers=headers
)
assert res.status_code == 200 assert res.status_code == 200
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers) res = await client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
assert res.status_code == 429 assert res.status_code == 429