From e5f11581013064af07c64cbf0d6853a1b8dbd0ce Mon Sep 17 00:00:00 2001 From: thePR0M3TH3AN <53631862+PR0M3TH3AN@users.noreply.github.com> Date: Sun, 3 Aug 2025 08:41:22 -0400 Subject: [PATCH] Test API rate limiting --- src/requirements.txt | 1 + src/seedpass/api.py | 36 +++++++++++++++++++++++++++ src/tests/test_api_rate_limit.py | 42 ++++++++++++++++++++++++++++++++ 3 files changed, 79 insertions(+) create mode 100644 src/tests/test_api_rate_limit.py diff --git a/src/requirements.txt b/src/requirements.txt index e40f6a5..59ef524 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -37,3 +37,4 @@ argon2-cffi toga-core>=0.5.2 pillow toga-dummy>=0.5.2 # for headless GUI tests +slowapi diff --git a/src/seedpass/api.py b/src/seedpass/api.py index fc58d2c..e7cd0c0 100644 --- a/src/seedpass/api.py +++ b/src/seedpass/api.py @@ -17,11 +17,21 @@ import asyncio import sys from fastapi.middleware.cors import CORSMiddleware +from slowapi import Limiter, _rate_limit_exceeded_handler +from slowapi.errors import RateLimitExceeded +from slowapi.util import get_remote_address +from slowapi.middleware import SlowAPIMiddleware + from seedpass.core.manager import PasswordManager from seedpass.core.entry_types import EntryType from seedpass.core.api import UtilityService +_RATE_LIMIT = int(os.getenv("SEEDPASS_RATE_LIMIT", "100")) +_RATE_WINDOW = int(os.getenv("SEEDPASS_RATE_WINDOW", "60")) +_RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds" + +limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR]) app = FastAPI() _pm: Optional[PasswordManager] = None @@ -71,6 +81,10 @@ def start_server(fingerprint: str | None = None) -> str: _jwt_secret = secrets.token_urlsafe(32) payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)} _token = jwt.encode(payload, _jwt_secret, algorithm="HS256") + if not getattr(app.state, "limiter", None): + app.state.limiter = limiter + app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) + app.add_middleware(SlowAPIMiddleware) origins = [ o.strip() for o in os.getenv("SEEDPASS_CORS_ORIGINS", "").split(",") @@ -157,6 +171,7 @@ def create_entry( "min_special", ] kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None} + index = _pm.entry_manager.add_entry( entry.get("label"), int(entry.get("length", 12)), @@ -168,6 +183,7 @@ def create_entry( if etype == "totp": index = _pm.entry_manager.get_next_index() + uri = _pm.entry_manager.add_totp( entry.get("label"), _pm.parent_seed, @@ -295,6 +311,7 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any: _check_token(authorization) assert _pm is not None value = _pm.config_manager.load_config(require_pin=False).get(key) + if value is None: raise HTTPException(status_code=404, detail="Not found") return {"key": key, "value": value} @@ -320,6 +337,7 @@ def update_config( } action = mapping.get(key) + if action is None: raise HTTPException(status_code=400, detail="Unknown key") @@ -338,7 +356,9 @@ def set_secret_mode( _check_token(authorization) assert _pm is not None enabled = data.get("enabled") + delay = data.get("delay") + if enabled is None or delay is None: raise HTTPException(status_code=400, detail="Missing fields") cfg = _pm.config_manager @@ -384,6 +404,7 @@ def select_fingerprint( _check_token(authorization) assert _pm is not None fp = data.get("fingerprint") + if not fp: raise HTTPException(status_code=400, detail="Missing fingerprint") _pm.select_fingerprint(fp) @@ -409,7 +430,9 @@ def get_totp_codes(authorization: str | None = Header(None)) -> dict: codes = [] for idx, label, _u, _url, _arch in entries: code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed) + rem = _pm.entry_manager.get_totp_time_remaining(idx) + codes.append( {"id": idx, "label": label, "code": code, "seconds_remaining": rem} ) @@ -433,6 +456,7 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]: while True: try: note = _pm.notifications.get_nowait() + except queue.Empty: break notes.append({"level": note.level, "message": note.message}) @@ -461,10 +485,12 @@ def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str, _check_token(authorization) assert _pm is not None url = data.get("url") + if not url: raise HTTPException(status_code=400, detail="Missing url") cfg = _pm.config_manager.load_config(require_pin=False) relays = cfg.get("relays", []) + if url in relays: raise HTTPException(status_code=400, detail="Relay already present") relays.append(url) @@ -480,6 +506,7 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str assert _pm is not None cfg = _pm.config_manager.load_config(require_pin=False) relays = cfg.get("relays", []) + if not (1 <= idx <= len(relays)): raise HTTPException(status_code=400, detail="Invalid index") if len(relays) == 1: @@ -546,9 +573,11 @@ async def import_vault( assert _pm is not None ctype = request.headers.get("content-type", "") + if ctype.startswith("multipart/form-data"): form = await request.form() file = form.get("file") + if file is None: raise HTTPException(status_code=400, detail="Missing file") data = await file.read() @@ -562,6 +591,7 @@ async def import_vault( else: body = await request.json() path = body.get("path") + if not path: raise HTTPException(status_code=400, detail="Missing file or path") _pm.handle_import_database(Path(path)) @@ -581,9 +611,11 @@ def backup_parent_seed( assert _pm is not None if not data.get("confirm"): + raise HTTPException(status_code=400, detail="Confirmation required") path_str = data.get("path") + if not path_str: raise HTTPException(status_code=400, detail="Missing path") path = Path(path_str) @@ -600,6 +632,7 @@ def change_password( _check_token(authorization) assert _pm is not None _pm.change_password(data.get("old", ""), data.get("new", "")) + return {"status": "ok"} @@ -611,6 +644,7 @@ def generate_password( _check_token(authorization) assert _pm is not None length = int(data.get("length", 12)) + policy_keys = [ "include_special_chars", "allowed_special_chars", @@ -622,6 +656,7 @@ def generate_password( "min_special", ] kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None} + util = UtilityService(_pm) password = util.generate_password(length, **kwargs) return {"password": password} @@ -640,4 +675,5 @@ def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]: async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]: _check_token(authorization) asyncio.get_event_loop().call_soon(sys.exit, 0) + return {"status": "shutting down"} diff --git a/src/tests/test_api_rate_limit.py b/src/tests/test_api_rate_limit.py new file mode 100644 index 0000000..ca42732 --- /dev/null +++ b/src/tests/test_api_rate_limit.py @@ -0,0 +1,42 @@ +import importlib +from pathlib import Path +from types import SimpleNamespace + +from fastapi.testclient import TestClient + +import sys + +sys.path.append(str(Path(__file__).resolve().parents[1])) + + +def test_rate_limit_exceeded(monkeypatch): + monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2") + monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60") + import seedpass.api as api + + importlib.reload(api) + + dummy = SimpleNamespace( + entry_manager=SimpleNamespace( + search_entries=lambda q: [ + (1, "Site", "user", "url", False, SimpleNamespace(value="password")) + ] + ), + config_manager=SimpleNamespace(load_config=lambda require_pin=False: {}), + fingerprint_manager=SimpleNamespace(list_fingerprints=lambda: []), + nostr_client=SimpleNamespace( + key_manager=SimpleNamespace(get_npub=lambda: "np") + ), + verify_password=lambda pw: True, + ) + monkeypatch.setattr(api, "PasswordManager", lambda: dummy) + token = api.start_server() + client = TestClient(api.app) + headers = {"Authorization": f"Bearer {token}"} + + for _ in range(2): + res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers) + assert res.status_code == 200 + + res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers) + assert res.status_code == 429