Test API rate limiting

This commit is contained in:
thePR0M3TH3AN
2025-08-03 08:41:22 -04:00
parent c7df96aac5
commit e5f1158101
3 changed files with 79 additions and 0 deletions

View File

@@ -37,3 +37,4 @@ argon2-cffi
toga-core>=0.5.2
pillow
toga-dummy>=0.5.2 # for headless GUI tests
slowapi

View File

@@ -17,11 +17,21 @@ import asyncio
import sys
from fastapi.middleware.cors import CORSMiddleware
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.util import get_remote_address
from slowapi.middleware import SlowAPIMiddleware
from seedpass.core.manager import PasswordManager
from seedpass.core.entry_types import EntryType
from seedpass.core.api import UtilityService
_RATE_LIMIT = int(os.getenv("SEEDPASS_RATE_LIMIT", "100"))
_RATE_WINDOW = int(os.getenv("SEEDPASS_RATE_WINDOW", "60"))
_RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds"
limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR])
app = FastAPI()
_pm: Optional[PasswordManager] = None
@@ -71,6 +81,10 @@ def start_server(fingerprint: str | None = None) -> str:
_jwt_secret = secrets.token_urlsafe(32)
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
if not getattr(app.state, "limiter", None):
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
app.add_middleware(SlowAPIMiddleware)
origins = [
o.strip()
for o in os.getenv("SEEDPASS_CORS_ORIGINS", "").split(",")
@@ -157,6 +171,7 @@ def create_entry(
"min_special",
]
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
index = _pm.entry_manager.add_entry(
entry.get("label"),
int(entry.get("length", 12)),
@@ -168,6 +183,7 @@ def create_entry(
if etype == "totp":
index = _pm.entry_manager.get_next_index()
uri = _pm.entry_manager.add_totp(
entry.get("label"),
_pm.parent_seed,
@@ -295,6 +311,7 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any:
_check_token(authorization)
assert _pm is not None
value = _pm.config_manager.load_config(require_pin=False).get(key)
if value is None:
raise HTTPException(status_code=404, detail="Not found")
return {"key": key, "value": value}
@@ -320,6 +337,7 @@ def update_config(
}
action = mapping.get(key)
if action is None:
raise HTTPException(status_code=400, detail="Unknown key")
@@ -338,7 +356,9 @@ def set_secret_mode(
_check_token(authorization)
assert _pm is not None
enabled = data.get("enabled")
delay = data.get("delay")
if enabled is None or delay is None:
raise HTTPException(status_code=400, detail="Missing fields")
cfg = _pm.config_manager
@@ -384,6 +404,7 @@ def select_fingerprint(
_check_token(authorization)
assert _pm is not None
fp = data.get("fingerprint")
if not fp:
raise HTTPException(status_code=400, detail="Missing fingerprint")
_pm.select_fingerprint(fp)
@@ -409,7 +430,9 @@ def get_totp_codes(authorization: str | None = Header(None)) -> dict:
codes = []
for idx, label, _u, _url, _arch in entries:
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
rem = _pm.entry_manager.get_totp_time_remaining(idx)
codes.append(
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
)
@@ -433,6 +456,7 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
while True:
try:
note = _pm.notifications.get_nowait()
except queue.Empty:
break
notes.append({"level": note.level, "message": note.message})
@@ -461,10 +485,12 @@ def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str,
_check_token(authorization)
assert _pm is not None
url = data.get("url")
if not url:
raise HTTPException(status_code=400, detail="Missing url")
cfg = _pm.config_manager.load_config(require_pin=False)
relays = cfg.get("relays", [])
if url in relays:
raise HTTPException(status_code=400, detail="Relay already present")
relays.append(url)
@@ -480,6 +506,7 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str
assert _pm is not None
cfg = _pm.config_manager.load_config(require_pin=False)
relays = cfg.get("relays", [])
if not (1 <= idx <= len(relays)):
raise HTTPException(status_code=400, detail="Invalid index")
if len(relays) == 1:
@@ -546,9 +573,11 @@ async def import_vault(
assert _pm is not None
ctype = request.headers.get("content-type", "")
if ctype.startswith("multipart/form-data"):
form = await request.form()
file = form.get("file")
if file is None:
raise HTTPException(status_code=400, detail="Missing file")
data = await file.read()
@@ -562,6 +591,7 @@ async def import_vault(
else:
body = await request.json()
path = body.get("path")
if not path:
raise HTTPException(status_code=400, detail="Missing file or path")
_pm.handle_import_database(Path(path))
@@ -581,9 +611,11 @@ def backup_parent_seed(
assert _pm is not None
if not data.get("confirm"):
raise HTTPException(status_code=400, detail="Confirmation required")
path_str = data.get("path")
if not path_str:
raise HTTPException(status_code=400, detail="Missing path")
path = Path(path_str)
@@ -600,6 +632,7 @@ def change_password(
_check_token(authorization)
assert _pm is not None
_pm.change_password(data.get("old", ""), data.get("new", ""))
return {"status": "ok"}
@@ -611,6 +644,7 @@ def generate_password(
_check_token(authorization)
assert _pm is not None
length = int(data.get("length", 12))
policy_keys = [
"include_special_chars",
"allowed_special_chars",
@@ -622,6 +656,7 @@ def generate_password(
"min_special",
]
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
util = UtilityService(_pm)
password = util.generate_password(length, **kwargs)
return {"password": password}
@@ -640,4 +675,5 @@ def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]:
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
_check_token(authorization)
asyncio.get_event_loop().call_soon(sys.exit, 0)
return {"status": "shutting down"}

View File

@@ -0,0 +1,42 @@
import importlib
from pathlib import Path
from types import SimpleNamespace
from fastapi.testclient import TestClient
import sys
sys.path.append(str(Path(__file__).resolve().parents[1]))
def test_rate_limit_exceeded(monkeypatch):
monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2")
monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60")
import seedpass.api as api
importlib.reload(api)
dummy = SimpleNamespace(
entry_manager=SimpleNamespace(
search_entries=lambda q: [
(1, "Site", "user", "url", False, SimpleNamespace(value="password"))
]
),
config_manager=SimpleNamespace(load_config=lambda require_pin=False: {}),
fingerprint_manager=SimpleNamespace(list_fingerprints=lambda: []),
nostr_client=SimpleNamespace(
key_manager=SimpleNamespace(get_npub=lambda: "np")
),
verify_password=lambda pw: True,
)
monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
token = api.start_server()
client = TestClient(api.app)
headers = {"Authorization": f"Bearer {token}"}
for _ in range(2):
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
assert res.status_code == 200
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
assert res.status_code == 429