mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-08 07:18:47 +00:00
Merge pull request #771 from PR0M3TH3AN/codex/convert-api-functions-to-async
Refactor API endpoints to async
This commit is contained in:
@@ -14,6 +14,7 @@ import jwt
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI, Header, HTTPException, Request, Response
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
import asyncio
|
||||
import sys
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -132,12 +133,12 @@ def _validate_encryption_path(request: Request, path: Path) -> Path:
|
||||
|
||||
|
||||
@app.get("/api/v1/entry")
|
||||
def search_entry(
|
||||
async def search_entry(
|
||||
request: Request, query: str, authorization: str | None = Header(None)
|
||||
) -> List[Any]:
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
results = pm.entry_manager.search_entries(query)
|
||||
results = await run_in_threadpool(pm.entry_manager.search_entries, query)
|
||||
return [
|
||||
{
|
||||
"id": idx,
|
||||
@@ -152,7 +153,7 @@ def search_entry(
|
||||
|
||||
|
||||
@app.get("/api/v1/entry/{entry_id}")
|
||||
def get_entry(
|
||||
async def get_entry(
|
||||
request: Request,
|
||||
entry_id: int,
|
||||
authorization: str | None = Header(None),
|
||||
@@ -161,14 +162,14 @@ def get_entry(
|
||||
_check_token(request, authorization)
|
||||
_require_password(request, password)
|
||||
pm = _get_pm(request)
|
||||
entry = pm.entry_manager.retrieve_entry(entry_id)
|
||||
entry = await run_in_threadpool(pm.entry_manager.retrieve_entry, entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return entry
|
||||
|
||||
|
||||
@app.post("/api/v1/entry")
|
||||
def create_entry(
|
||||
async def create_entry(
|
||||
request: Request,
|
||||
entry: dict,
|
||||
authorization: str | None = Header(None),
|
||||
@@ -197,7 +198,8 @@ def create_entry(
|
||||
]
|
||||
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
|
||||
|
||||
index = pm.entry_manager.add_entry(
|
||||
index = await run_in_threadpool(
|
||||
pm.entry_manager.add_entry,
|
||||
entry.get("label"),
|
||||
int(entry.get("length", 12)),
|
||||
entry.get("username"),
|
||||
@@ -207,9 +209,10 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "totp":
|
||||
index = pm.entry_manager.get_next_index()
|
||||
index = await run_in_threadpool(pm.entry_manager.get_next_index)
|
||||
|
||||
uri = pm.entry_manager.add_totp(
|
||||
uri = await run_in_threadpool(
|
||||
pm.entry_manager.add_totp,
|
||||
entry.get("label"),
|
||||
pm.parent_seed,
|
||||
secret=entry.get("secret"),
|
||||
@@ -222,7 +225,8 @@ def create_entry(
|
||||
return {"id": index, "uri": uri}
|
||||
|
||||
if etype == "ssh":
|
||||
index = pm.entry_manager.add_ssh_key(
|
||||
index = await run_in_threadpool(
|
||||
pm.entry_manager.add_ssh_key,
|
||||
entry.get("label"),
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
@@ -232,7 +236,8 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "pgp":
|
||||
index = pm.entry_manager.add_pgp_key(
|
||||
index = await run_in_threadpool(
|
||||
pm.entry_manager.add_pgp_key,
|
||||
entry.get("label"),
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
@@ -244,7 +249,8 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "nostr":
|
||||
index = pm.entry_manager.add_nostr_key(
|
||||
index = await run_in_threadpool(
|
||||
pm.entry_manager.add_nostr_key,
|
||||
entry.get("label"),
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
@@ -254,7 +260,8 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "key_value":
|
||||
index = pm.entry_manager.add_key_value(
|
||||
index = await run_in_threadpool(
|
||||
pm.entry_manager.add_key_value,
|
||||
entry.get("label"),
|
||||
entry.get("key"),
|
||||
entry.get("value"),
|
||||
@@ -268,7 +275,8 @@ def create_entry(
|
||||
if etype == "seed"
|
||||
else pm.entry_manager.add_managed_account
|
||||
)
|
||||
index = func(
|
||||
index = await run_in_threadpool(
|
||||
func,
|
||||
entry.get("label"),
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
|
@@ -1,7 +1,15 @@
|
||||
import importlib.util
|
||||
import logging
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture(
|
||||
params=["asyncio"] + (["trio"] if importlib.util.find_spec("trio") else [])
|
||||
)
|
||||
def anyio_backend(request):
|
||||
return request.param
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mute_logging():
|
||||
logging.getLogger().setLevel(logging.WARNING)
|
||||
|
@@ -3,7 +3,7 @@ from pathlib import Path
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
from fastapi.testclient import TestClient
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
import hashlib
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||
@@ -13,7 +13,7 @@ from seedpass.core.entry_types import EntryType
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def client(monkeypatch):
|
||||
async def client(monkeypatch):
|
||||
dummy = SimpleNamespace(
|
||||
entry_manager=SimpleNamespace(
|
||||
search_entries=lambda q: [
|
||||
@@ -45,27 +45,31 @@ def client(monkeypatch):
|
||||
monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
|
||||
monkeypatch.setenv("SEEDPASS_CORS_ORIGINS", "http://example.com")
|
||||
token = api.start_server()
|
||||
client = TestClient(api.app)
|
||||
return client, token
|
||||
transport = ASGITransport(app=api.app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as ac:
|
||||
yield ac, token
|
||||
|
||||
|
||||
def test_token_hashed(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_token_hashed(client):
|
||||
_, token = client
|
||||
assert api.app.state.token_hash != token
|
||||
assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def test_cors_and_auth(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_cors_and_auth(client):
|
||||
cl, token = client
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
res = await cl.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_invalid_token(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_invalid_token(client):
|
||||
cl, _token = client
|
||||
res = cl.get(
|
||||
res = await cl.get(
|
||||
"/api/v1/entry",
|
||||
params={"query": "s"},
|
||||
headers={"Authorization": "Bearer bad"},
|
||||
@@ -73,60 +77,65 @@ def test_invalid_token(client):
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_get_entry_by_id(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_get_entry_by_id(client):
|
||||
cl, token = client
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": "http://example.com",
|
||||
"X-SeedPass-Password": "pw",
|
||||
}
|
||||
res = cl.get("/api/v1/entry/1", headers=headers)
|
||||
res = await cl.get("/api/v1/entry/1", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"label": "Site"}
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_get_config_value(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_get_config_value(client):
|
||||
cl, token = client
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": "http://example.com",
|
||||
}
|
||||
res = cl.get("/api/v1/config/k", headers=headers)
|
||||
res = await cl.get("/api/v1/config/k", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"key": "k", "value": "v"}
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_list_fingerprint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_list_fingerprint(client):
|
||||
cl, token = client
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": "http://example.com",
|
||||
}
|
||||
res = cl.get("/api/v1/fingerprint", headers=headers)
|
||||
res = await cl.get("/api/v1/fingerprint", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == ["fp"]
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_get_nostr_pubkey(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_get_nostr_pubkey(client):
|
||||
cl, token = client
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": "http://example.com",
|
||||
}
|
||||
res = cl.get("/api/v1/nostr/pubkey", headers=headers)
|
||||
res = await cl.get("/api/v1/nostr/pubkey", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"npub": "np"}
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_create_modify_archive_entry(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_create_modify_archive_entry(client):
|
||||
cl, token = client
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/entry",
|
||||
json={"label": "test", "length": 12},
|
||||
headers=headers,
|
||||
@@ -134,7 +143,7 @@ def test_create_modify_archive_entry(client):
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"id": 1}
|
||||
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/entry/1",
|
||||
json={"username": "bob"},
|
||||
headers=headers,
|
||||
@@ -142,16 +151,17 @@ def test_create_modify_archive_entry(client):
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "ok"}
|
||||
|
||||
res = cl.post("/api/v1/entry/1/archive", headers=headers)
|
||||
res = await cl.post("/api/v1/entry/1/archive", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "archived"}
|
||||
|
||||
res = cl.post("/api/v1/entry/1/unarchive", headers=headers)
|
||||
res = await cl.post("/api/v1/entry/1/unarchive", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "active"}
|
||||
|
||||
|
||||
def test_update_config(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_update_config(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
@@ -160,7 +170,7 @@ def test_update_config(client):
|
||||
|
||||
api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/config/inactivity_timeout",
|
||||
json={"value": 42},
|
||||
headers=headers,
|
||||
@@ -171,14 +181,15 @@ def test_update_config(client):
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_update_config_quick_unlock(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_update_config_quick_unlock(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault(
|
||||
"val", v
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/config/quick_unlock",
|
||||
json={"value": True},
|
||||
headers=headers,
|
||||
@@ -188,12 +199,13 @@ def test_update_config_quick_unlock(client):
|
||||
assert called.get("val") is True
|
||||
|
||||
|
||||
def test_change_password_route(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_change_password_route(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n))
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/change-password",
|
||||
headers=headers,
|
||||
json={"old": "old", "new": "new"},
|
||||
@@ -204,10 +216,11 @@ def test_change_password_route(client):
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
def test_update_config_unknown_key(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_update_config_unknown_key(client):
|
||||
cl, token = client
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/config/bogus",
|
||||
json={"value": 1},
|
||||
headers=headers,
|
||||
@@ -215,7 +228,8 @@ def test_update_config_unknown_key(client):
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_shutdown(client, monkeypatch):
|
||||
@pytest.mark.anyio
|
||||
async def test_shutdown(client, monkeypatch):
|
||||
cl, token = client
|
||||
|
||||
calls = {}
|
||||
@@ -231,7 +245,7 @@ def test_shutdown(client, monkeypatch):
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Origin": "http://example.com",
|
||||
}
|
||||
res = cl.post("/api/v1/shutdown", headers=headers)
|
||||
res = await cl.post("/api/v1/shutdown", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "shutting down"}
|
||||
assert calls["func"] is sys.exit
|
||||
@@ -239,6 +253,7 @@ def test_shutdown(client, monkeypatch):
|
||||
assert res.headers.get("access-control-allow-origin") == "http://example.com"
|
||||
|
||||
|
||||
@pytest.mark.anyio
|
||||
@pytest.mark.parametrize(
|
||||
"method,path",
|
||||
[
|
||||
@@ -256,11 +271,11 @@ def test_shutdown(client, monkeypatch):
|
||||
("post", "/api/v1/vault/lock"),
|
||||
],
|
||||
)
|
||||
def test_invalid_token_other_endpoints(client, method, path):
|
||||
async def test_invalid_token_other_endpoints(client, method, path):
|
||||
cl, _token = client
|
||||
req = getattr(cl, method)
|
||||
kwargs = {"headers": {"Authorization": "Bearer bad"}}
|
||||
if method in {"post", "put"}:
|
||||
kwargs["json"] = {}
|
||||
res = req(path, **kwargs)
|
||||
res = await req(path, **kwargs)
|
||||
assert res.status_code == 401
|
||||
|
@@ -13,7 +13,8 @@ from seedpass.core.encryption import EncryptionManager
|
||||
from nostr.client import NostrClient, DEFAULT_RELAYS
|
||||
|
||||
|
||||
def test_create_and_modify_totp_entry(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_create_and_modify_totp_entry(client):
|
||||
cl, token = client
|
||||
calls = {}
|
||||
|
||||
@@ -30,7 +31,7 @@ def test_create_and_modify_totp_entry(client):
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/entry",
|
||||
json={
|
||||
"type": "totp",
|
||||
@@ -54,7 +55,7 @@ def test_create_and_modify_totp_entry(client):
|
||||
"archived": False,
|
||||
}
|
||||
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/entry/5",
|
||||
json={"period": 90, "digits": 6},
|
||||
headers=headers,
|
||||
@@ -65,7 +66,8 @@ def test_create_and_modify_totp_entry(client):
|
||||
assert calls["modify"][1]["digits"] == 6
|
||||
|
||||
|
||||
def test_create_and_modify_ssh_entry(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_create_and_modify_ssh_entry(client):
|
||||
cl, token = client
|
||||
calls = {}
|
||||
|
||||
@@ -81,7 +83,7 @@ def test_create_and_modify_ssh_entry(client):
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/entry",
|
||||
json={"type": "ssh", "label": "S", "index": 2, "notes": "n"},
|
||||
headers=headers,
|
||||
@@ -90,7 +92,7 @@ def test_create_and_modify_ssh_entry(client):
|
||||
assert res.json() == {"id": 2}
|
||||
assert calls["create"] == {"index": 2, "notes": "n", "archived": False}
|
||||
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/entry/2",
|
||||
json={"notes": "x"},
|
||||
headers=headers,
|
||||
@@ -100,7 +102,8 @@ def test_create_and_modify_ssh_entry(client):
|
||||
assert calls["modify"][1]["notes"] == "x"
|
||||
|
||||
|
||||
def test_update_entry_error(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_update_entry_error(client):
|
||||
cl, token = client
|
||||
|
||||
def modify(*a, **k):
|
||||
@@ -108,12 +111,13 @@ def test_update_entry_error(client):
|
||||
|
||||
api.app.state.pm.entry_manager.modify_entry = modify
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers)
|
||||
res = await cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers)
|
||||
assert res.status_code == 400
|
||||
assert res.json() == {"detail": "nope"}
|
||||
|
||||
|
||||
def test_update_config_secret_mode(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_update_config_secret_mode(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
@@ -122,7 +126,7 @@ def test_update_config_secret_mode(client):
|
||||
|
||||
api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.put(
|
||||
res = await cl.put(
|
||||
"/api/v1/config/secret_mode_enabled",
|
||||
json={"value": True},
|
||||
headers=headers,
|
||||
@@ -132,17 +136,19 @@ def test_update_config_secret_mode(client):
|
||||
assert called["val"] is True
|
||||
|
||||
|
||||
def test_totp_export_endpoint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_totp_export_endpoint(client):
|
||||
cl, token = client
|
||||
api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]}
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
|
||||
res = cl.get("/api/v1/totp/export", headers=headers)
|
||||
res = await cl.get("/api/v1/totp/export", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"entries": ["x"]}
|
||||
|
||||
|
||||
def test_totp_codes_endpoint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_totp_codes_endpoint(client):
|
||||
cl, token = client
|
||||
api.app.state.pm.entry_manager.list_entries = lambda **kw: [
|
||||
(0, "Email", None, None, False)
|
||||
@@ -151,7 +157,7 @@ def test_totp_codes_endpoint(client):
|
||||
api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
|
||||
res = cl.get("/api/v1/totp", headers=headers)
|
||||
res = await cl.get("/api/v1/totp", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {
|
||||
"codes": [
|
||||
@@ -160,13 +166,17 @@ def test_totp_codes_endpoint(client):
|
||||
}
|
||||
|
||||
|
||||
def test_parent_seed_endpoint_removed(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_parent_seed_endpoint_removed(client):
|
||||
cl, token = client
|
||||
res = cl.get("/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.get(
|
||||
"/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert res.status_code == 404
|
||||
|
||||
|
||||
def test_fingerprint_endpoints(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_fingerprint_endpoints(client):
|
||||
cl, token = client
|
||||
calls = {}
|
||||
|
||||
@@ -178,17 +188,17 @@ def test_fingerprint_endpoints(client):
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
res = cl.post("/api/v1/fingerprint", headers=headers)
|
||||
res = await cl.post("/api/v1/fingerprint", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "ok"}
|
||||
assert calls.get("add") is True
|
||||
|
||||
res = cl.delete("/api/v1/fingerprint/abc", headers=headers)
|
||||
res = await cl.delete("/api/v1/fingerprint/abc", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "deleted"}
|
||||
assert calls.get("remove") == "abc"
|
||||
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/fingerprint/select",
|
||||
json={"fingerprint": "xyz"},
|
||||
headers=headers,
|
||||
@@ -198,7 +208,8 @@ def test_fingerprint_endpoints(client):
|
||||
assert calls.get("select") == "xyz"
|
||||
|
||||
|
||||
def test_checksum_endpoints(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_checksum_endpoints(client):
|
||||
cl, token = client
|
||||
calls = {}
|
||||
|
||||
@@ -209,18 +220,19 @@ def test_checksum_endpoints(client):
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
res = cl.post("/api/v1/checksum/verify", headers=headers)
|
||||
res = await cl.post("/api/v1/checksum/verify", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "ok"}
|
||||
assert calls.get("verify") is True
|
||||
|
||||
res = cl.post("/api/v1/checksum/update", headers=headers)
|
||||
res = await cl.post("/api/v1/checksum/update", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "ok"}
|
||||
assert calls.get("update") is True
|
||||
|
||||
|
||||
def test_vault_import_via_path(client, tmp_path):
|
||||
@pytest.mark.anyio
|
||||
async def test_vault_import_via_path(client, tmp_path):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
@@ -236,7 +248,7 @@ def test_vault_import_via_path(client, tmp_path):
|
||||
file_path.write_text("{}")
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/import",
|
||||
json={"path": str(file_path)},
|
||||
headers=headers,
|
||||
@@ -247,7 +259,8 @@ def test_vault_import_via_path(client, tmp_path):
|
||||
assert called.get("sync") is True
|
||||
|
||||
|
||||
def test_vault_import_via_upload(client, tmp_path):
|
||||
@pytest.mark.anyio
|
||||
async def test_vault_import_via_upload(client, tmp_path):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
@@ -261,7 +274,7 @@ def test_vault_import_via_upload(client, tmp_path):
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
with open(file_path, "rb") as fh:
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/import",
|
||||
files={"file": ("c.json", fh.read())},
|
||||
headers=headers,
|
||||
@@ -272,7 +285,8 @@ def test_vault_import_via_upload(client, tmp_path):
|
||||
assert called.get("sync") is True
|
||||
|
||||
|
||||
def test_vault_import_invalid_extension(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_vault_import_invalid_extension(client):
|
||||
cl, token = client
|
||||
api.app.state.pm.handle_import_database = lambda path: None
|
||||
api.app.state.pm.sync_vault = lambda: None
|
||||
@@ -281,7 +295,7 @@ def test_vault_import_invalid_extension(client):
|
||||
)
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/import",
|
||||
json={"path": "bad.txt"},
|
||||
headers=headers,
|
||||
@@ -289,7 +303,8 @@ def test_vault_import_invalid_extension(client):
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_vault_import_path_traversal_blocked(client, tmp_path):
|
||||
@pytest.mark.anyio
|
||||
async def test_vault_import_path_traversal_blocked(client, tmp_path):
|
||||
cl, token = client
|
||||
key = base64.urlsafe_b64encode(os.urandom(32))
|
||||
api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path)
|
||||
@@ -297,7 +312,7 @@ def test_vault_import_path_traversal_blocked(client, tmp_path):
|
||||
api.app.state.pm.sync_vault = lambda: None
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/import",
|
||||
json={"path": "../evil.json.enc"},
|
||||
headers=headers,
|
||||
@@ -305,7 +320,8 @@ def test_vault_import_path_traversal_blocked(client, tmp_path):
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_vault_lock_endpoint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_vault_lock_endpoint(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
@@ -317,7 +333,7 @@ def test_vault_lock_endpoint(client):
|
||||
api.app.state.pm.locked = False
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post("/api/v1/vault/lock", headers=headers)
|
||||
res = await cl.post("/api/v1/vault/lock", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "locked"}
|
||||
assert called.get("locked") is True
|
||||
@@ -329,7 +345,8 @@ def test_vault_lock_endpoint(client):
|
||||
assert api.app.state.pm.locked is False
|
||||
|
||||
|
||||
def test_secret_mode_endpoint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_secret_mode_endpoint(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
@@ -343,7 +360,7 @@ def test_secret_mode_endpoint(client):
|
||||
api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/secret-mode",
|
||||
json={"enabled": True, "delay": 12},
|
||||
headers=headers,
|
||||
@@ -354,7 +371,8 @@ def test_secret_mode_endpoint(client):
|
||||
assert called["delay"] == 12
|
||||
|
||||
|
||||
def test_vault_export_endpoint(client, tmp_path):
|
||||
@pytest.mark.anyio
|
||||
async def test_vault_export_endpoint(client, tmp_path):
|
||||
cl, token = client
|
||||
out = tmp_path / "out.json"
|
||||
out.write_text("data")
|
||||
@@ -365,15 +383,18 @@ def test_vault_export_endpoint(client, tmp_path):
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-SeedPass-Password": "pw",
|
||||
}
|
||||
res = cl.post("/api/v1/vault/export", headers=headers)
|
||||
res = await cl.post("/api/v1/vault/export", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.content == b"data"
|
||||
|
||||
res = cl.post("/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert res.status_code == 401
|
||||
|
||||
|
||||
def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
@pytest.mark.anyio
|
||||
async def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
cl, token = client
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
called = {}
|
||||
@@ -386,7 +407,7 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-SeedPass-Password": "pw",
|
||||
}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/backup-parent-seed",
|
||||
json={"path": str(path), "confirm": True},
|
||||
headers=headers,
|
||||
@@ -395,7 +416,7 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
assert res.json() == {"status": "saved", "path": str(path)}
|
||||
assert called["path"] == path
|
||||
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/backup-parent-seed",
|
||||
json={"path": str(path)},
|
||||
headers=headers,
|
||||
@@ -403,7 +424,8 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
|
||||
@pytest.mark.anyio
|
||||
async def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
|
||||
cl, token = client
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
key = base64.urlsafe_b64encode(os.urandom(32))
|
||||
@@ -412,7 +434,7 @@ def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-SeedPass-Password": "pw",
|
||||
}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/vault/backup-parent-seed",
|
||||
json={"path": "../evil.enc", "confirm": True},
|
||||
headers=headers,
|
||||
@@ -420,7 +442,8 @@ def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
|
||||
assert res.status_code == 400
|
||||
|
||||
|
||||
def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
@pytest.mark.anyio
|
||||
async def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
cl, token = client
|
||||
nostr_client, _ = dummy_nostr_client
|
||||
relays = ["wss://a", "wss://b"]
|
||||
@@ -448,28 +471,29 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
res = cl.get("/api/v1/relays", headers=headers)
|
||||
res = await cl.get("/api/v1/relays", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"relays": relays}
|
||||
|
||||
res = cl.post("/api/v1/relays", json={"url": "wss://c"}, headers=headers)
|
||||
res = await cl.post("/api/v1/relays", json={"url": "wss://c"}, headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert called["set"] == ["wss://a", "wss://b", "wss://c"]
|
||||
|
||||
api.app.state.pm.config_manager.load_config = lambda require_pin=False: {
|
||||
"relays": ["wss://a", "wss://b", "wss://c"]
|
||||
}
|
||||
res = cl.delete("/api/v1/relays/2", headers=headers)
|
||||
res = await cl.delete("/api/v1/relays/2", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert called["set"] == ["wss://a", "wss://c"]
|
||||
|
||||
res = cl.post("/api/v1/relays/reset", headers=headers)
|
||||
res = await cl.post("/api/v1/relays/reset", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert called.get("init") is True
|
||||
assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS)
|
||||
|
||||
|
||||
def test_generate_password_no_special_chars(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_generate_password_no_special_chars(client):
|
||||
cl, token = client
|
||||
|
||||
class DummyEnc:
|
||||
@@ -486,7 +510,7 @@ def test_generate_password_no_special_chars(client):
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/password",
|
||||
json={"length": 16, "include_special_chars": False},
|
||||
headers=headers,
|
||||
@@ -496,7 +520,8 @@ def test_generate_password_no_special_chars(client):
|
||||
assert not any(c in string.punctuation for c in pw)
|
||||
|
||||
|
||||
def test_generate_password_allowed_chars(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_generate_password_allowed_chars(client):
|
||||
cl, token = client
|
||||
|
||||
class DummyEnc:
|
||||
@@ -514,7 +539,7 @@ def test_generate_password_allowed_chars(client):
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
allowed = "@$"
|
||||
res = cl.post(
|
||||
res = await cl.post(
|
||||
"/api/v1/password",
|
||||
json={"length": 16, "allowed_special_chars": allowed},
|
||||
headers=headers,
|
||||
|
@@ -1,15 +1,19 @@
|
||||
from test_api import client
|
||||
from types import SimpleNamespace
|
||||
import queue
|
||||
import pytest
|
||||
import seedpass.api as api
|
||||
|
||||
|
||||
def test_notifications_endpoint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_notifications_endpoint(client):
|
||||
cl, token = client
|
||||
api.app.state.pm.notifications = queue.Queue()
|
||||
api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO"))
|
||||
api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING"))
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.get(
|
||||
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == [
|
||||
{"level": "INFO", "message": "m1"},
|
||||
@@ -18,19 +22,25 @@ def test_notifications_endpoint(client):
|
||||
assert api.app.state.pm.notifications.empty()
|
||||
|
||||
|
||||
def test_notifications_endpoint_clears_queue(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_notifications_endpoint_clears_queue(client):
|
||||
cl, token = client
|
||||
api.app.state.pm.notifications = queue.Queue()
|
||||
api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO"))
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.get(
|
||||
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == [{"level": "INFO", "message": "hi"}]
|
||||
assert api.app.state.pm.notifications.empty()
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.get(
|
||||
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert res.json() == []
|
||||
|
||||
|
||||
def test_notifications_endpoint_does_not_clear_current(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_notifications_endpoint_does_not_clear_current(client):
|
||||
cl, token = client
|
||||
api.app.state.pm.notifications = queue.Queue()
|
||||
msg = SimpleNamespace(message="keep", level="INFO")
|
||||
@@ -40,7 +50,9 @@ def test_notifications_endpoint_does_not_clear_current(client):
|
||||
lambda: api.app.state.pm._current_notification
|
||||
)
|
||||
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.get(
|
||||
"/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}
|
||||
)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == [{"level": "INFO", "message": "keep"}]
|
||||
assert api.app.state.pm.notifications.empty()
|
||||
|
@@ -1,13 +1,14 @@
|
||||
from test_api import client
|
||||
import pytest
|
||||
|
||||
|
||||
def test_profile_stats_endpoint(client):
|
||||
@pytest.mark.anyio
|
||||
async def test_profile_stats_endpoint(client):
|
||||
cl, token = client
|
||||
stats = {"total_entries": 1}
|
||||
# monkeypatch set _pm.get_profile_stats after client fixture started
|
||||
import seedpass.api as api
|
||||
|
||||
api.app.state.pm.get_profile_stats = lambda: stats
|
||||
res = cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"})
|
||||
res = await cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == stats
|
||||
|
@@ -1,15 +1,17 @@
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
import importlib
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
|
||||
def test_rate_limit_exceeded(monkeypatch):
|
||||
@pytest.mark.anyio
|
||||
async def test_rate_limit_exceeded(monkeypatch):
|
||||
monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2")
|
||||
monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60")
|
||||
import seedpass.api as api
|
||||
@@ -31,12 +33,15 @@ def test_rate_limit_exceeded(monkeypatch):
|
||||
)
|
||||
monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
|
||||
token = api.start_server()
|
||||
client = TestClient(api.app)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
transport = ASGITransport(app=api.app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
for _ in range(2):
|
||||
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
assert res.status_code == 200
|
||||
for _ in range(2):
|
||||
res = await client.get(
|
||||
"/api/v1/entry", params={"query": "s"}, headers=headers
|
||||
)
|
||||
assert res.status_code == 200
|
||||
|
||||
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
assert res.status_code == 429
|
||||
res = await client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
assert res.status_code == 429
|
||||
|
Reference in New Issue
Block a user