diff --git a/src/seedpass/api.py b/src/seedpass/api.py index 4a8e13f..62f9bd6 100644 --- a/src/seedpass/api.py +++ b/src/seedpass/api.py @@ -36,35 +36,39 @@ _RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds" limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR]) app = FastAPI() -_pm: Optional[PasswordManager] = None -_token: str = "" -_jwt_secret: str = "" + +def _get_pm(request: Request) -> PasswordManager: + pm = getattr(request.app.state, "pm", None) + assert pm is not None + return pm -def _check_token(auth: str | None) -> None: +def _check_token(request: Request, auth: str | None) -> None: if auth is None or not auth.startswith("Bearer "): raise HTTPException(status_code=401, detail="Unauthorized") token = auth.split(" ", 1)[1] + jwt_secret = getattr(request.app.state, "jwt_secret", "") + token_hash = getattr(request.app.state, "token_hash", "") try: - jwt.decode(token, _jwt_secret, algorithms=["HS256"]) + jwt.decode(token, jwt_secret, algorithms=["HS256"]) except jwt.ExpiredSignatureError: raise HTTPException(status_code=401, detail="Token expired") except jwt.InvalidTokenError: raise HTTPException(status_code=401, detail="Unauthorized") - if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), _token): + if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), token_hash): raise HTTPException(status_code=401, detail="Unauthorized") -def _reload_relays(relays: list[str]) -> None: +def _reload_relays(request: Request, relays: list[str]) -> None: """Reload the Nostr client with a new relay list.""" - assert _pm is not None + pm = _get_pm(request) try: - _pm.nostr_client.close_client_pool() + pm.nostr_client.close_client_pool() except Exception: pass try: - _pm.nostr_client.relays = relays - _pm.nostr_client.initialize_client_pool() + pm.nostr_client.relays = relays + pm.nostr_client.initialize_client_pool() except Exception: pass @@ -77,15 +81,15 @@ def start_server(fingerprint: str | None = None) -> str: fingerprint: Optional seed profile fingerprint to select before starting the server. """ - global _pm, _token, _jwt_secret if fingerprint is None: - _pm = PasswordManager() + pm = PasswordManager() else: - _pm = PasswordManager(fingerprint=fingerprint) - _jwt_secret = secrets.token_urlsafe(32) + pm = PasswordManager(fingerprint=fingerprint) + app.state.pm = pm + app.state.jwt_secret = secrets.token_urlsafe(32) payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)} - raw_token = jwt.encode(payload, _jwt_secret, algorithm="HS256") - _token = hashlib.sha256(raw_token.encode()).hexdigest() + raw_token = jwt.encode(payload, app.state.jwt_secret, algorithm="HS256") + app.state.token_hash = hashlib.sha256(raw_token.encode()).hexdigest() if not getattr(app.state, "limiter", None): app.state.limiter = limiter app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) @@ -105,30 +109,32 @@ def start_server(fingerprint: str | None = None) -> str: return raw_token -def _require_password(password: str | None) -> None: - assert _pm is not None - if password is None or not _pm.verify_password(password): +def _require_password(request: Request, password: str | None) -> None: + pm = _get_pm(request) + if password is None or not pm.verify_password(password): raise HTTPException(status_code=401, detail="Invalid password") -def _validate_encryption_path(path: Path) -> Path: +def _validate_encryption_path(request: Request, path: Path) -> Path: """Validate and normalize ``path`` within the active fingerprint directory. Returns the resolved absolute path if validation succeeds. """ - assert _pm is not None + pm = _get_pm(request) try: - return _pm.encryption_manager.resolve_relative_path(path) + return pm.encryption_manager.resolve_relative_path(path) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @app.get("/api/v1/entry") -def search_entry(query: str, authorization: str | None = Header(None)) -> List[Any]: - _check_token(authorization) - assert _pm is not None - results = _pm.entry_manager.search_entries(query) +def search_entry( + request: Request, query: str, authorization: str | None = Header(None) +) -> List[Any]: + _check_token(request, authorization) + pm = _get_pm(request) + results = pm.entry_manager.search_entries(query) return [ { "id": idx, @@ -144,14 +150,15 @@ def search_entry(query: str, authorization: str | None = Header(None)) -> List[A @app.get("/api/v1/entry/{entry_id}") def get_entry( + request: Request, entry_id: int, authorization: str | None = Header(None), password: str | None = Header(None, alias="X-SeedPass-Password"), ) -> Any: - _check_token(authorization) - _require_password(password) - assert _pm is not None - entry = _pm.entry_manager.retrieve_entry(entry_id) + _check_token(request, authorization) + _require_password(request, password) + pm = _get_pm(request) + entry = pm.entry_manager.retrieve_entry(entry_id) if entry is None: raise HTTPException(status_code=404, detail="Not found") return entry @@ -159,6 +166,7 @@ def get_entry( @app.post("/api/v1/entry") def create_entry( + request: Request, entry: dict, authorization: str | None = Header(None), ) -> dict[str, Any]: @@ -168,8 +176,8 @@ def create_entry( on, the corresponding entry type is created. When omitted or set to ``password`` the behaviour matches the legacy password-entry API. """ - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) etype = (entry.get("type") or entry.get("kind") or "password").lower() @@ -186,7 +194,7 @@ def create_entry( ] kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None} - index = _pm.entry_manager.add_entry( + index = pm.entry_manager.add_entry( entry.get("label"), int(entry.get("length", 12)), entry.get("username"), @@ -196,11 +204,11 @@ def create_entry( return {"id": index} if etype == "totp": - index = _pm.entry_manager.get_next_index() + index = pm.entry_manager.get_next_index() - uri = _pm.entry_manager.add_totp( + uri = pm.entry_manager.add_totp( entry.get("label"), - _pm.parent_seed, + pm.parent_seed, secret=entry.get("secret"), index=entry.get("index"), period=int(entry.get("period", 30)), @@ -211,9 +219,9 @@ def create_entry( return {"id": index, "uri": uri} if etype == "ssh": - index = _pm.entry_manager.add_ssh_key( + index = pm.entry_manager.add_ssh_key( entry.get("label"), - _pm.parent_seed, + pm.parent_seed, index=entry.get("index"), notes=entry.get("notes", ""), archived=entry.get("archived", False), @@ -221,9 +229,9 @@ def create_entry( return {"id": index} if etype == "pgp": - index = _pm.entry_manager.add_pgp_key( + index = pm.entry_manager.add_pgp_key( entry.get("label"), - _pm.parent_seed, + pm.parent_seed, index=entry.get("index"), key_type=entry.get("key_type", "ed25519"), user_id=entry.get("user_id", ""), @@ -233,9 +241,9 @@ def create_entry( return {"id": index} if etype == "nostr": - index = _pm.entry_manager.add_nostr_key( + index = pm.entry_manager.add_nostr_key( entry.get("label"), - _pm.parent_seed, + pm.parent_seed, index=entry.get("index"), notes=entry.get("notes", ""), archived=entry.get("archived", False), @@ -243,7 +251,7 @@ def create_entry( return {"id": index} if etype == "key_value": - index = _pm.entry_manager.add_key_value( + index = pm.entry_manager.add_key_value( entry.get("label"), entry.get("key"), entry.get("value"), @@ -253,13 +261,13 @@ def create_entry( if etype in {"seed", "managed_account"}: func = ( - _pm.entry_manager.add_seed + pm.entry_manager.add_seed if etype == "seed" - else _pm.entry_manager.add_managed_account + else pm.entry_manager.add_managed_account ) index = func( entry.get("label"), - _pm.parent_seed, + pm.parent_seed, index=entry.get("index"), notes=entry.get("notes", ""), ) @@ -270,6 +278,7 @@ def create_entry( @app.put("/api/v1/entry/{entry_id}") def update_entry( + request: Request, entry_id: int, entry: dict, authorization: str | None = Header(None), @@ -279,10 +288,10 @@ def update_entry( Additional fields like ``period``, ``digits`` and ``value`` are forwarded for specialized entry types (e.g. TOTP or key/value entries). """ - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) try: - _pm.entry_manager.modify_entry( + pm.entry_manager.modify_entry( entry_id, username=entry.get("username"), url=entry.get("url"), @@ -300,31 +309,33 @@ def update_entry( @app.post("/api/v1/entry/{entry_id}/archive") def archive_entry( - entry_id: int, authorization: str | None = Header(None) + request: Request, entry_id: int, authorization: str | None = Header(None) ) -> dict[str, str]: """Archive an entry.""" - _check_token(authorization) - assert _pm is not None - _pm.entry_manager.archive_entry(entry_id) + _check_token(request, authorization) + pm = _get_pm(request) + pm.entry_manager.archive_entry(entry_id) return {"status": "archived"} @app.post("/api/v1/entry/{entry_id}/unarchive") def unarchive_entry( - entry_id: int, authorization: str | None = Header(None) + request: Request, entry_id: int, authorization: str | None = Header(None) ) -> dict[str, str]: """Restore an archived entry.""" - _check_token(authorization) - assert _pm is not None - _pm.entry_manager.restore_entry(entry_id) + _check_token(request, authorization) + pm = _get_pm(request) + pm.entry_manager.restore_entry(entry_id) return {"status": "active"} @app.get("/api/v1/config/{key}") -def get_config(key: str, authorization: str | None = Header(None)) -> Any: - _check_token(authorization) - assert _pm is not None - value = _pm.config_manager.load_config(require_pin=False).get(key) +def get_config( + request: Request, key: str, authorization: str | None = Header(None) +) -> Any: + _check_token(request, authorization) + pm = _get_pm(request) + value = pm.config_manager.load_config(require_pin=False).get(key) if value is None: raise HTTPException(status_code=404, detail="Not found") @@ -333,12 +344,15 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any: @app.put("/api/v1/config/{key}") def update_config( - key: str, data: dict, authorization: str | None = Header(None) + request: Request, + key: str, + data: dict, + authorization: str | None = Header(None), ) -> dict[str, str]: """Update a configuration setting.""" - _check_token(authorization) - assert _pm is not None - cfg = _pm.config_manager + _check_token(request, authorization) + pm = _get_pm(request) + cfg = pm.config_manager mapping = { "relays": lambda v: cfg.set_relays(v, require_pin=False), "pin": cfg.set_pin, @@ -364,96 +378,102 @@ def update_config( @app.post("/api/v1/secret-mode") def set_secret_mode( - data: dict, authorization: str | None = Header(None) + request: Request, data: dict, authorization: str | None = Header(None) ) -> dict[str, str]: """Enable/disable secret mode and set the clipboard delay.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) enabled = data.get("enabled") delay = data.get("delay") if enabled is None or delay is None: raise HTTPException(status_code=400, detail="Missing fields") - cfg = _pm.config_manager + cfg = pm.config_manager cfg.set_secret_mode_enabled(bool(enabled)) cfg.set_clipboard_clear_delay(int(delay)) - _pm.secret_mode_enabled = bool(enabled) - _pm.clipboard_clear_delay = int(delay) + pm.secret_mode_enabled = bool(enabled) + pm.clipboard_clear_delay = int(delay) return {"status": "ok"} @app.get("/api/v1/fingerprint") -def list_fingerprints(authorization: str | None = Header(None)) -> List[str]: - _check_token(authorization) - assert _pm is not None - return _pm.fingerprint_manager.list_fingerprints() +def list_fingerprints( + request: Request, authorization: str | None = Header(None) +) -> List[str]: + _check_token(request, authorization) + pm = _get_pm(request) + return pm.fingerprint_manager.list_fingerprints() @app.post("/api/v1/fingerprint") -def add_fingerprint(authorization: str | None = Header(None)) -> dict[str, str]: +def add_fingerprint( + request: Request, authorization: str | None = Header(None) +) -> dict[str, str]: """Create a new seed profile.""" - _check_token(authorization) - assert _pm is not None - _pm.add_new_fingerprint() + _check_token(request, authorization) + pm = _get_pm(request) + pm.add_new_fingerprint() return {"status": "ok"} @app.delete("/api/v1/fingerprint/{fingerprint}") def remove_fingerprint( - fingerprint: str, authorization: str | None = Header(None) + request: Request, fingerprint: str, authorization: str | None = Header(None) ) -> dict[str, str]: """Remove a seed profile.""" - _check_token(authorization) - assert _pm is not None - _pm.fingerprint_manager.remove_fingerprint(fingerprint) + _check_token(request, authorization) + pm = _get_pm(request) + pm.fingerprint_manager.remove_fingerprint(fingerprint) return {"status": "deleted"} @app.post("/api/v1/fingerprint/select") def select_fingerprint( - data: dict, authorization: str | None = Header(None) + request: Request, data: dict, authorization: str | None = Header(None) ) -> dict[str, str]: """Switch the active seed profile.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) fp = data.get("fingerprint") if not fp: raise HTTPException(status_code=400, detail="Missing fingerprint") - _pm.select_fingerprint(fp) + pm.select_fingerprint(fp) return {"status": "ok"} @app.get("/api/v1/totp/export") def export_totp( + request: Request, authorization: str | None = Header(None), password: str | None = Header(None, alias="X-SeedPass-Password"), ) -> dict: """Return all stored TOTP entries in JSON format.""" - _check_token(authorization) - _require_password(password) - assert _pm is not None - return _pm.entry_manager.export_totp_entries(_pm.parent_seed) + _check_token(request, authorization) + _require_password(request, password) + pm = _get_pm(request) + return pm.entry_manager.export_totp_entries(pm.parent_seed) @app.get("/api/v1/totp") def get_totp_codes( + request: Request, authorization: str | None = Header(None), password: str | None = Header(None, alias="X-SeedPass-Password"), ) -> dict: """Return active TOTP codes with remaining seconds.""" - _check_token(authorization) - _require_password(password) - assert _pm is not None - entries = _pm.entry_manager.list_entries( + _check_token(request, authorization) + _require_password(request, password) + pm = _get_pm(request) + entries = pm.entry_manager.list_entries( filter_kind=EntryType.TOTP.value, include_archived=False ) codes = [] for idx, label, _u, _url, _arch in entries: - code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed) + code = pm.entry_manager.get_totp_code(idx, pm.parent_seed) - rem = _pm.entry_manager.get_totp_time_remaining(idx) + rem = pm.entry_manager.get_totp_time_remaining(idx) codes.append( {"id": idx, "label": label, "code": code, "seconds_remaining": rem} @@ -462,23 +482,26 @@ def get_totp_codes( @app.get("/api/v1/stats") -def get_profile_stats(authorization: str | None = Header(None)) -> dict: +def get_profile_stats( + request: Request, authorization: str | None = Header(None) +) -> dict: """Return statistics about the active seed profile.""" - _check_token(authorization) - assert _pm is not None - return _pm.get_profile_stats() + _check_token(request, authorization) + pm = _get_pm(request) + return pm.get_profile_stats() @app.get("/api/v1/notifications") -def get_notifications(authorization: str | None = Header(None)) -> List[dict]: +def get_notifications( + request: Request, authorization: str | None = Header(None) +) -> List[dict]: """Return and clear queued notifications.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) notes = [] while True: try: - note = _pm.notifications.get_nowait() - + note = pm.notifications.get_nowait() except queue.Empty: break notes.append({"level": note.level, "message": note.message}) @@ -486,47 +509,51 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]: @app.get("/api/v1/nostr/pubkey") -def get_nostr_pubkey(authorization: str | None = Header(None)) -> Any: - _check_token(authorization) - assert _pm is not None - return {"npub": _pm.nostr_client.key_manager.get_npub()} +def get_nostr_pubkey(request: Request, authorization: str | None = Header(None)) -> Any: + _check_token(request, authorization) + pm = _get_pm(request) + return {"npub": pm.nostr_client.key_manager.get_npub()} @app.get("/api/v1/relays") -def list_relays(authorization: str | None = Header(None)) -> dict: +def list_relays(request: Request, authorization: str | None = Header(None)) -> dict: """Return the configured Nostr relays.""" - _check_token(authorization) - assert _pm is not None - cfg = _pm.config_manager.load_config(require_pin=False) + _check_token(request, authorization) + pm = _get_pm(request) + cfg = pm.config_manager.load_config(require_pin=False) return {"relays": cfg.get("relays", [])} @app.post("/api/v1/relays") -def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str, str]: +def add_relay( + request: Request, data: dict, authorization: str | None = Header(None) +) -> dict[str, str]: """Add a relay URL to the configuration.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) url = data.get("url") if not url: raise HTTPException(status_code=400, detail="Missing url") - cfg = _pm.config_manager.load_config(require_pin=False) + cfg = pm.config_manager.load_config(require_pin=False) relays = cfg.get("relays", []) if url in relays: raise HTTPException(status_code=400, detail="Relay already present") relays.append(url) - _pm.config_manager.set_relays(relays, require_pin=False) - _reload_relays(relays) + pm.config_manager.set_relays(relays, require_pin=False) + _reload_relays(request, relays) return {"status": "ok"} @app.delete("/api/v1/relays/{idx}") -def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str, str]: +def remove_relay( + request: Request, idx: int, authorization: str | None = Header(None) +) -> dict[str, str]: """Remove a relay by its index (1-based).""" - _check_token(authorization) - assert _pm is not None - cfg = _pm.config_manager.load_config(require_pin=False) + _check_token(request, authorization) + pm = _get_pm(request) + cfg = pm.config_manager.load_config(require_pin=False) relays = cfg.get("relays", []) if not (1 <= idx <= len(relays)): @@ -534,52 +561,59 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str if len(relays) == 1: raise HTTPException(status_code=400, detail="At least one relay required") relays.pop(idx - 1) - _pm.config_manager.set_relays(relays, require_pin=False) - _reload_relays(relays) + pm.config_manager.set_relays(relays, require_pin=False) + _reload_relays(request, relays) return {"status": "ok"} @app.post("/api/v1/relays/reset") -def reset_relays(authorization: str | None = Header(None)) -> dict[str, str]: +def reset_relays( + request: Request, authorization: str | None = Header(None) +) -> dict[str, str]: """Reset relay list to defaults.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) from nostr.client import DEFAULT_RELAYS relays = list(DEFAULT_RELAYS) - _pm.config_manager.set_relays(relays, require_pin=False) - _reload_relays(relays) + pm.config_manager.set_relays(relays, require_pin=False) + _reload_relays(request, relays) return {"status": "ok"} @app.post("/api/v1/checksum/verify") -def verify_checksum(authorization: str | None = Header(None)) -> dict[str, str]: +def verify_checksum( + request: Request, authorization: str | None = Header(None) +) -> dict[str, str]: """Verify the SeedPass script checksum.""" - _check_token(authorization) - assert _pm is not None - _pm.handle_verify_checksum() + _check_token(request, authorization) + pm = _get_pm(request) + pm.handle_verify_checksum() return {"status": "ok"} @app.post("/api/v1/checksum/update") -def update_checksum(authorization: str | None = Header(None)) -> dict[str, str]: +def update_checksum( + request: Request, authorization: str | None = Header(None) +) -> dict[str, str]: """Regenerate the script checksum file.""" - _check_token(authorization) - assert _pm is not None - _pm.handle_update_script_checksum() + _check_token(request, authorization) + pm = _get_pm(request) + pm.handle_update_script_checksum() return {"status": "ok"} @app.post("/api/v1/vault/export") def export_vault( + request: Request, authorization: str | None = Header(None), password: str | None = Header(None, alias="X-SeedPass-Password"), ): """Export the vault and return the encrypted file.""" - _check_token(authorization) - _require_password(password) - assert _pm is not None - path = _pm.handle_export_database() + _check_token(request, authorization) + _require_password(request, password) + pm = _get_pm(request) + path = pm.handle_export_database() if path is None: raise HTTPException(status_code=500, detail="Export failed") data = Path(path).read_bytes() @@ -591,8 +625,8 @@ async def import_vault( request: Request, authorization: str | None = Header(None) ) -> dict[str, str]: """Import a vault backup from a file upload or a server path.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) ctype = request.headers.get("content-type", "") @@ -607,7 +641,7 @@ async def import_vault( tmp.write(data) tmp_path = Path(tmp.name) try: - _pm.handle_import_database(tmp_path) + pm.handle_import_database(tmp_path) finally: os.unlink(tmp_path) else: @@ -617,28 +651,29 @@ async def import_vault( if not path_str: raise HTTPException(status_code=400, detail="Missing file or path") - path = _validate_encryption_path(Path(path_str)) + path = _validate_encryption_path(request, Path(path_str)) if not str(path).endswith(".json.enc"): raise HTTPException( status_code=400, detail="Selected file must be a '.json.enc' backup", ) - _pm.handle_import_database(path) - _pm.sync_vault() + pm.handle_import_database(path) + pm.sync_vault() return {"status": "ok"} @app.post("/api/v1/vault/backup-parent-seed") def backup_parent_seed( + request: Request, data: dict, authorization: str | None = Header(None), password: str | None = Header(None, alias="X-SeedPass-Password"), ) -> dict[str, str]: """Create an encrypted backup of the parent seed after confirmation.""" - _check_token(authorization) - _require_password(password) - assert _pm is not None + _check_token(request, authorization) + _require_password(request, password) + pm = _get_pm(request) if not data.get("confirm"): @@ -649,30 +684,30 @@ def backup_parent_seed( if not path_str: raise HTTPException(status_code=400, detail="Missing path") path = Path(path_str) - _validate_encryption_path(path) - _pm.encryption_manager.encrypt_and_save_file(_pm.parent_seed.encode("utf-8"), path) + _validate_encryption_path(request, path) + pm.encryption_manager.encrypt_and_save_file(pm.parent_seed.encode("utf-8"), path) return {"status": "saved", "path": str(path)} @app.post("/api/v1/change-password") def change_password( - data: dict, authorization: str | None = Header(None) + request: Request, data: dict, authorization: str | None = Header(None) ) -> dict[str, str]: """Change the master password for the active profile.""" - _check_token(authorization) - assert _pm is not None - _pm.change_password(data.get("old", ""), data.get("new", "")) + _check_token(request, authorization) + pm = _get_pm(request) + pm.change_password(data.get("old", ""), data.get("new", "")) return {"status": "ok"} @app.post("/api/v1/password") def generate_password( - data: dict, authorization: str | None = Header(None) + request: Request, data: dict, authorization: str | None = Header(None) ) -> dict[str, str]: """Generate a password using optional policy overrides.""" - _check_token(authorization) - assert _pm is not None + _check_token(request, authorization) + pm = _get_pm(request) length = int(data.get("length", 12)) policy_keys = [ @@ -687,23 +722,27 @@ def generate_password( ] kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None} - util = UtilityService(_pm) + util = UtilityService(pm) password = util.generate_password(length, **kwargs) return {"password": password} @app.post("/api/v1/vault/lock") -def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]: +def lock_vault( + request: Request, authorization: str | None = Header(None) +) -> dict[str, str]: """Lock the vault and clear sensitive data from memory.""" - _check_token(authorization) - assert _pm is not None - _pm.lock_vault() + _check_token(request, authorization) + pm = _get_pm(request) + pm.lock_vault() return {"status": "locked"} @app.post("/api/v1/shutdown") -async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]: - _check_token(authorization) +async def shutdown_server( + request: Request, authorization: str | None = Header(None) +) -> dict[str, str]: + _check_token(request, authorization) asyncio.get_event_loop().call_soon(sys.exit, 0) return {"status": "shutting down"} diff --git a/src/seedpass/cli/api.py b/src/seedpass/cli/api.py index 8ebfe29..8c8b10a 100644 --- a/src/seedpass/cli/api.py +++ b/src/seedpass/cli/api.py @@ -25,7 +25,7 @@ def api_stop(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> N try: requests.post( f"http://{host}:{port}/api/v1/shutdown", - headers={"Authorization": f"Bearer {api_module._token}"}, + headers={"Authorization": f"Bearer {api_module.app.state.token_hash}"}, timeout=2, ) except Exception as exc: # pragma: no cover - best effort diff --git a/src/tests/test_api.py b/src/tests/test_api.py index 67d0551..9abf7f9 100644 --- a/src/tests/test_api.py +++ b/src/tests/test_api.py @@ -51,8 +51,8 @@ def client(monkeypatch): def test_token_hashed(client): _, token = client - assert api._token != token - assert api._token == hashlib.sha256(token.encode()).hexdigest() + assert api.app.state.token_hash != token + assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest() def test_cors_and_auth(client): @@ -158,7 +158,7 @@ def test_update_config(client): def set_timeout(val): called["val"] = val - api._pm.config_manager.set_inactivity_timeout = set_timeout + api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} res = cl.put( "/api/v1/config/inactivity_timeout", @@ -174,8 +174,9 @@ def test_update_config(client): def test_update_config_quick_unlock(client): cl, token = client called = {} - - api._pm.config_manager.set_quick_unlock = lambda v: called.setdefault("val", v) + api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault( + "val", v + ) headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} res = cl.put( "/api/v1/config/quick_unlock", @@ -190,8 +191,7 @@ def test_update_config_quick_unlock(client): def test_change_password_route(client): cl, token = client called = {} - - api._pm.change_password = lambda o, n: called.setdefault("called", (o, n)) + api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n)) headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"} res = cl.post( "/api/v1/change-password", diff --git a/src/tests/test_api_new_endpoints.py b/src/tests/test_api_new_endpoints.py index 2e1b501..df5f29f 100644 --- a/src/tests/test_api_new_endpoints.py +++ b/src/tests/test_api_new_endpoints.py @@ -3,7 +3,6 @@ from pathlib import Path import os import base64 import pytest -from types import SimpleNamespace from seedpass import api from test_api import client @@ -25,10 +24,10 @@ def test_create_and_modify_totp_entry(client): def modify(idx, **kwargs): calls["modify"] = (idx, kwargs) - api._pm.entry_manager.add_totp = add_totp - api._pm.entry_manager.modify_entry = modify - api._pm.entry_manager.get_next_index = lambda: 5 - api._pm.parent_seed = "seed" + api.app.state.pm.entry_manager.add_totp = add_totp + api.app.state.pm.entry_manager.modify_entry = modify + api.app.state.pm.entry_manager.get_next_index = lambda: 5 + api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} res = cl.post( @@ -77,9 +76,9 @@ def test_create_and_modify_ssh_entry(client): def modify(idx, **kwargs): calls["modify"] = (idx, kwargs) - api._pm.entry_manager.add_ssh_key = add_ssh - api._pm.entry_manager.modify_entry = modify - api._pm.parent_seed = "seed" + api.app.state.pm.entry_manager.add_ssh_key = add_ssh + api.app.state.pm.entry_manager.modify_entry = modify + api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} res = cl.post( @@ -107,7 +106,7 @@ def test_update_entry_error(client): def modify(*a, **k): raise ValueError("nope") - api._pm.entry_manager.modify_entry = modify + api.app.state.pm.entry_manager.modify_entry = modify headers = {"Authorization": f"Bearer {token}"} res = cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers) assert res.status_code == 400 @@ -121,7 +120,7 @@ def test_update_config_secret_mode(client): def set_secret(val): called["val"] = val - api._pm.config_manager.set_secret_mode_enabled = set_secret + api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret headers = {"Authorization": f"Bearer {token}"} res = cl.put( "/api/v1/config/secret_mode_enabled", @@ -135,8 +134,8 @@ def test_update_config_secret_mode(client): def test_totp_export_endpoint(client): cl, token = client - api._pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]} - api._pm.parent_seed = "seed" + api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]} + api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"} res = cl.get("/api/v1/totp/export", headers=headers) assert res.status_code == 200 @@ -145,10 +144,12 @@ def test_totp_export_endpoint(client): def test_totp_codes_endpoint(client): cl, token = client - api._pm.entry_manager.list_entries = lambda **kw: [(0, "Email", None, None, False)] - api._pm.entry_manager.get_totp_code = lambda i, s: "123456" - api._pm.entry_manager.get_totp_time_remaining = lambda i: 30 - api._pm.parent_seed = "seed" + api.app.state.pm.entry_manager.list_entries = lambda **kw: [ + (0, "Email", None, None, False) + ] + api.app.state.pm.entry_manager.get_totp_code = lambda i, s: "123456" + api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30 + api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"} res = cl.get("/api/v1/totp", headers=headers) assert res.status_code == 200 @@ -169,11 +170,11 @@ def test_fingerprint_endpoints(client): cl, token = client calls = {} - api._pm.add_new_fingerprint = lambda: calls.setdefault("add", True) - api._pm.fingerprint_manager.remove_fingerprint = lambda fp: calls.setdefault( - "remove", fp + api.app.state.pm.add_new_fingerprint = lambda: calls.setdefault("add", True) + api.app.state.pm.fingerprint_manager.remove_fingerprint = ( + lambda fp: calls.setdefault("remove", fp) ) - api._pm.select_fingerprint = lambda fp: calls.setdefault("select", fp) + api.app.state.pm.select_fingerprint = lambda fp: calls.setdefault("select", fp) headers = {"Authorization": f"Bearer {token}"} @@ -201,8 +202,10 @@ def test_checksum_endpoints(client): cl, token = client calls = {} - api._pm.handle_verify_checksum = lambda: calls.setdefault("verify", True) - api._pm.handle_update_script_checksum = lambda: calls.setdefault("update", True) + api.app.state.pm.handle_verify_checksum = lambda: calls.setdefault("verify", True) + api.app.state.pm.handle_update_script_checksum = lambda: calls.setdefault( + "update", True + ) headers = {"Authorization": f"Bearer {token}"} @@ -224,9 +227,11 @@ def test_vault_import_via_path(client, tmp_path): def import_db(path): called["path"] = path - api._pm.handle_import_database = import_db - api._pm.sync_vault = lambda: called.setdefault("sync", True) - api._pm.encryption_manager = SimpleNamespace(resolve_relative_path=lambda p: p) + api.app.state.pm.handle_import_database = import_db + api.app.state.pm.sync_vault = lambda: called.setdefault("sync", True) + api.app.state.pm.encryption_manager = SimpleNamespace( + resolve_relative_path=lambda p: p + ) file_path = tmp_path / "b.json.enc" file_path.write_text("{}") @@ -249,8 +254,8 @@ def test_vault_import_via_upload(client, tmp_path): def import_db(path): called["path"] = path - api._pm.handle_import_database = import_db - api._pm.sync_vault = lambda: called.setdefault("sync", True) + api.app.state.pm.handle_import_database = import_db + api.app.state.pm.sync_vault = lambda: called.setdefault("sync", True) file_path = tmp_path / "c.json" file_path.write_text("{}") @@ -269,9 +274,11 @@ def test_vault_import_via_upload(client, tmp_path): def test_vault_import_invalid_extension(client): cl, token = client - api._pm.handle_import_database = lambda path: None - api._pm.sync_vault = lambda: None - api._pm.encryption_manager = SimpleNamespace(resolve_relative_path=lambda p: p) + api.app.state.pm.handle_import_database = lambda path: None + api.app.state.pm.sync_vault = lambda: None + api.app.state.pm.encryption_manager = SimpleNamespace( + resolve_relative_path=lambda p: p + ) headers = {"Authorization": f"Bearer {token}"} res = cl.post( @@ -285,9 +292,9 @@ def test_vault_import_invalid_extension(client): def test_vault_import_path_traversal_blocked(client, tmp_path): cl, token = client key = base64.urlsafe_b64encode(os.urandom(32)) - api._pm.encryption_manager = EncryptionManager(key, tmp_path) - api._pm.handle_import_database = lambda path: None - api._pm.sync_vault = lambda: None + api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path) + api.app.state.pm.handle_import_database = lambda path: None + api.app.state.pm.sync_vault = lambda: None headers = {"Authorization": f"Bearer {token}"} res = cl.post( @@ -304,20 +311,22 @@ def test_vault_lock_endpoint(client): def lock(): called["locked"] = True - api._pm.locked = True + api.app.state.pm.locked = True - api._pm.lock_vault = lock - api._pm.locked = False + api.app.state.pm.lock_vault = lock + api.app.state.pm.locked = False headers = {"Authorization": f"Bearer {token}"} res = cl.post("/api/v1/vault/lock", headers=headers) assert res.status_code == 200 assert res.json() == {"status": "locked"} assert called.get("locked") is True - assert api._pm.locked is True - api._pm.unlock_vault = lambda pw: setattr(api._pm, "locked", False) - api._pm.unlock_vault("pw") - assert api._pm.locked is False + assert api.app.state.pm.locked is True + api.app.state.pm.unlock_vault = lambda pw: setattr( + api.app.state.pm, "locked", False + ) + api.app.state.pm.unlock_vault("pw") + assert api.app.state.pm.locked is False def test_secret_mode_endpoint(client): @@ -330,8 +339,8 @@ def test_secret_mode_endpoint(client): def set_delay(val): called.setdefault("delay", val) - api._pm.config_manager.set_secret_mode_enabled = set_secret - api._pm.config_manager.set_clipboard_clear_delay = set_delay + api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret + api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay headers = {"Authorization": f"Bearer {token}"} res = cl.post( @@ -350,7 +359,7 @@ def test_vault_export_endpoint(client, tmp_path): out = tmp_path / "out.json" out.write_text("data") - api._pm.handle_export_database = lambda: out + api.app.state.pm.handle_export_database = lambda: out headers = { "Authorization": f"Bearer {token}", @@ -366,9 +375,9 @@ def test_vault_export_endpoint(client, tmp_path): def test_backup_parent_seed_endpoint(client, tmp_path): cl, token = client - api._pm.parent_seed = "seed" + api.app.state.pm.parent_seed = "seed" called = {} - api._pm.encryption_manager = SimpleNamespace( + api.app.state.pm.encryption_manager = SimpleNamespace( encrypt_and_save_file=lambda data, path: called.setdefault("path", path), resolve_relative_path=lambda p: p, ) @@ -396,9 +405,9 @@ def test_backup_parent_seed_endpoint(client, tmp_path): def test_backup_parent_seed_path_traversal_blocked(client, tmp_path): cl, token = client - api._pm.parent_seed = "seed" + api.app.state.pm.parent_seed = "seed" key = base64.urlsafe_b64encode(os.urandom(32)) - api._pm.encryption_manager = EncryptionManager(key, tmp_path) + api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path) headers = { "Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw", @@ -424,8 +433,8 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): def set_relays(new, require_pin=False): called["set"] = new - api._pm.config_manager.load_config = load_config - api._pm.config_manager.set_relays = set_relays + api.app.state.pm.config_manager.load_config = load_config + api.app.state.pm.config_manager.set_relays = set_relays monkeypatch.setattr( NostrClient, "initialize_client_pool", @@ -434,8 +443,8 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): monkeypatch.setattr( nostr_client, "close_client_pool", lambda: called.setdefault("close", True) ) - api._pm.nostr_client = nostr_client - api._pm.nostr_client.relays = relays.copy() + api.app.state.pm.nostr_client = nostr_client + api.app.state.pm.nostr_client.relays = relays.copy() headers = {"Authorization": f"Bearer {token}"} @@ -447,7 +456,7 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): assert res.status_code == 200 assert called["set"] == ["wss://a", "wss://b", "wss://c"] - api._pm.config_manager.load_config = lambda require_pin=False: { + api.app.state.pm.config_manager.load_config = lambda require_pin=False: { "relays": ["wss://a", "wss://b", "wss://c"] } res = cl.delete("/api/v1/relays/2", headers=headers) @@ -457,7 +466,7 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch): res = cl.post("/api/v1/relays/reset", headers=headers) assert res.status_code == 200 assert called.get("init") is True - assert api._pm.nostr_client.relays == list(DEFAULT_RELAYS) + assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS) def test_generate_password_no_special_chars(client): @@ -471,8 +480,10 @@ def test_generate_password_no_special_chars(client): def derive_entropy(self, index: int, bytes_len: int, app_no: int = 32) -> bytes: return bytes(range(bytes_len)) - api._pm.password_generator = PasswordGenerator(DummyEnc(), "seed", DummyBIP85()) - api._pm.parent_seed = "seed" + api.app.state.pm.password_generator = PasswordGenerator( + DummyEnc(), "seed", DummyBIP85() + ) + api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} res = cl.post( @@ -496,8 +507,10 @@ def test_generate_password_allowed_chars(client): def derive_entropy(self, index: int, bytes_len: int, app_no: int = 32) -> bytes: return bytes((index + i) % 256 for i in range(bytes_len)) - api._pm.password_generator = PasswordGenerator(DummyEnc(), "seed", DummyBIP85()) - api._pm.parent_seed = "seed" + api.app.state.pm.password_generator = PasswordGenerator( + DummyEnc(), "seed", DummyBIP85() + ) + api.app.state.pm.parent_seed = "seed" headers = {"Authorization": f"Bearer {token}"} allowed = "@$" diff --git a/src/tests/test_api_notifications.py b/src/tests/test_api_notifications.py index e0805a9..aefbd7d 100644 --- a/src/tests/test_api_notifications.py +++ b/src/tests/test_api_notifications.py @@ -6,40 +6,42 @@ import seedpass.api as api def test_notifications_endpoint(client): cl, token = client - api._pm.notifications = queue.Queue() - api._pm.notifications.put(SimpleNamespace(message="m1", level="INFO")) - api._pm.notifications.put(SimpleNamespace(message="m2", level="WARNING")) + api.app.state.pm.notifications = queue.Queue() + api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO")) + api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING")) res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) assert res.status_code == 200 assert res.json() == [ {"level": "INFO", "message": "m1"}, {"level": "WARNING", "message": "m2"}, ] - assert api._pm.notifications.empty() + assert api.app.state.pm.notifications.empty() def test_notifications_endpoint_clears_queue(client): cl, token = client - api._pm.notifications = queue.Queue() - api._pm.notifications.put(SimpleNamespace(message="hi", level="INFO")) + api.app.state.pm.notifications = queue.Queue() + api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO")) res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) assert res.status_code == 200 assert res.json() == [{"level": "INFO", "message": "hi"}] - assert api._pm.notifications.empty() + assert api.app.state.pm.notifications.empty() res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) assert res.json() == [] def test_notifications_endpoint_does_not_clear_current(client): cl, token = client - api._pm.notifications = queue.Queue() + api.app.state.pm.notifications = queue.Queue() msg = SimpleNamespace(message="keep", level="INFO") - api._pm.notifications.put(msg) - api._pm._current_notification = msg - api._pm.get_current_notification = lambda: api._pm._current_notification + api.app.state.pm.notifications.put(msg) + api.app.state.pm._current_notification = msg + api.app.state.pm.get_current_notification = ( + lambda: api.app.state.pm._current_notification + ) res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"}) assert res.status_code == 200 assert res.json() == [{"level": "INFO", "message": "keep"}] - assert api._pm.notifications.empty() - assert api._pm.get_current_notification() is msg + assert api.app.state.pm.notifications.empty() + assert api.app.state.pm.get_current_notification() is msg diff --git a/src/tests/test_api_profile_stats.py b/src/tests/test_api_profile_stats.py index 153dbd3..a3c62dc 100644 --- a/src/tests/test_api_profile_stats.py +++ b/src/tests/test_api_profile_stats.py @@ -7,7 +7,7 @@ def test_profile_stats_endpoint(client): # monkeypatch set _pm.get_profile_stats after client fixture started import seedpass.api as api - api._pm.get_profile_stats = lambda: stats + api.app.state.pm.get_profile_stats = lambda: stats res = cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"}) assert res.status_code == 200 assert res.json() == stats