Merge pull request #769 from PR0M3TH3AN/codex/refactor-to-use-fastapi-dependencies

refactor: move api state to app
This commit is contained in:
thePR0M3TH3AN
2025-08-05 19:19:58 -04:00
committed by GitHub
6 changed files with 302 additions and 248 deletions

View File

@@ -36,35 +36,39 @@ _RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds"
limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR])
app = FastAPI()
_pm: Optional[PasswordManager] = None
_token: str = ""
_jwt_secret: str = ""
def _get_pm(request: Request) -> PasswordManager:
pm = getattr(request.app.state, "pm", None)
assert pm is not None
return pm
def _check_token(auth: str | None) -> None:
def _check_token(request: Request, auth: str | None) -> None:
if auth is None or not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = auth.split(" ", 1)[1]
jwt_secret = getattr(request.app.state, "jwt_secret", "")
token_hash = getattr(request.app.state, "token_hash", "")
try:
jwt.decode(token, _jwt_secret, algorithms=["HS256"])
jwt.decode(token, jwt_secret, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Unauthorized")
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), _token):
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), token_hash):
raise HTTPException(status_code=401, detail="Unauthorized")
def _reload_relays(relays: list[str]) -> None:
def _reload_relays(request: Request, relays: list[str]) -> None:
"""Reload the Nostr client with a new relay list."""
assert _pm is not None
pm = _get_pm(request)
try:
_pm.nostr_client.close_client_pool()
pm.nostr_client.close_client_pool()
except Exception:
pass
try:
_pm.nostr_client.relays = relays
_pm.nostr_client.initialize_client_pool()
pm.nostr_client.relays = relays
pm.nostr_client.initialize_client_pool()
except Exception:
pass
@@ -77,15 +81,15 @@ def start_server(fingerprint: str | None = None) -> str:
fingerprint:
Optional seed profile fingerprint to select before starting the server.
"""
global _pm, _token, _jwt_secret
if fingerprint is None:
_pm = PasswordManager()
pm = PasswordManager()
else:
_pm = PasswordManager(fingerprint=fingerprint)
_jwt_secret = secrets.token_urlsafe(32)
pm = PasswordManager(fingerprint=fingerprint)
app.state.pm = pm
app.state.jwt_secret = secrets.token_urlsafe(32)
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
raw_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
_token = hashlib.sha256(raw_token.encode()).hexdigest()
raw_token = jwt.encode(payload, app.state.jwt_secret, algorithm="HS256")
app.state.token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
if not getattr(app.state, "limiter", None):
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
@@ -105,30 +109,32 @@ def start_server(fingerprint: str | None = None) -> str:
return raw_token
def _require_password(password: str | None) -> None:
assert _pm is not None
if password is None or not _pm.verify_password(password):
def _require_password(request: Request, password: str | None) -> None:
pm = _get_pm(request)
if password is None or not pm.verify_password(password):
raise HTTPException(status_code=401, detail="Invalid password")
def _validate_encryption_path(path: Path) -> Path:
def _validate_encryption_path(request: Request, path: Path) -> Path:
"""Validate and normalize ``path`` within the active fingerprint directory.
Returns the resolved absolute path if validation succeeds.
"""
assert _pm is not None
pm = _get_pm(request)
try:
return _pm.encryption_manager.resolve_relative_path(path)
return pm.encryption_manager.resolve_relative_path(path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/entry")
def search_entry(query: str, authorization: str | None = Header(None)) -> List[Any]:
_check_token(authorization)
assert _pm is not None
results = _pm.entry_manager.search_entries(query)
def search_entry(
request: Request, query: str, authorization: str | None = Header(None)
) -> List[Any]:
_check_token(request, authorization)
pm = _get_pm(request)
results = pm.entry_manager.search_entries(query)
return [
{
"id": idx,
@@ -144,14 +150,15 @@ def search_entry(query: str, authorization: str | None = Header(None)) -> List[A
@app.get("/api/v1/entry/{entry_id}")
def get_entry(
request: Request,
entry_id: int,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> Any:
_check_token(authorization)
_require_password(password)
assert _pm is not None
entry = _pm.entry_manager.retrieve_entry(entry_id)
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
entry = pm.entry_manager.retrieve_entry(entry_id)
if entry is None:
raise HTTPException(status_code=404, detail="Not found")
return entry
@@ -159,6 +166,7 @@ def get_entry(
@app.post("/api/v1/entry")
def create_entry(
request: Request,
entry: dict,
authorization: str | None = Header(None),
) -> dict[str, Any]:
@@ -168,8 +176,8 @@ def create_entry(
on, the corresponding entry type is created. When omitted or set to
``password`` the behaviour matches the legacy password-entry API.
"""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
etype = (entry.get("type") or entry.get("kind") or "password").lower()
@@ -186,7 +194,7 @@ def create_entry(
]
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
index = _pm.entry_manager.add_entry(
index = pm.entry_manager.add_entry(
entry.get("label"),
int(entry.get("length", 12)),
entry.get("username"),
@@ -196,11 +204,11 @@ def create_entry(
return {"id": index}
if etype == "totp":
index = _pm.entry_manager.get_next_index()
index = pm.entry_manager.get_next_index()
uri = _pm.entry_manager.add_totp(
uri = pm.entry_manager.add_totp(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
secret=entry.get("secret"),
index=entry.get("index"),
period=int(entry.get("period", 30)),
@@ -211,9 +219,9 @@ def create_entry(
return {"id": index, "uri": uri}
if etype == "ssh":
index = _pm.entry_manager.add_ssh_key(
index = pm.entry_manager.add_ssh_key(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
notes=entry.get("notes", ""),
archived=entry.get("archived", False),
@@ -221,9 +229,9 @@ def create_entry(
return {"id": index}
if etype == "pgp":
index = _pm.entry_manager.add_pgp_key(
index = pm.entry_manager.add_pgp_key(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
key_type=entry.get("key_type", "ed25519"),
user_id=entry.get("user_id", ""),
@@ -233,9 +241,9 @@ def create_entry(
return {"id": index}
if etype == "nostr":
index = _pm.entry_manager.add_nostr_key(
index = pm.entry_manager.add_nostr_key(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
notes=entry.get("notes", ""),
archived=entry.get("archived", False),
@@ -243,7 +251,7 @@ def create_entry(
return {"id": index}
if etype == "key_value":
index = _pm.entry_manager.add_key_value(
index = pm.entry_manager.add_key_value(
entry.get("label"),
entry.get("key"),
entry.get("value"),
@@ -253,13 +261,13 @@ def create_entry(
if etype in {"seed", "managed_account"}:
func = (
_pm.entry_manager.add_seed
pm.entry_manager.add_seed
if etype == "seed"
else _pm.entry_manager.add_managed_account
else pm.entry_manager.add_managed_account
)
index = func(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
notes=entry.get("notes", ""),
)
@@ -270,6 +278,7 @@ def create_entry(
@app.put("/api/v1/entry/{entry_id}")
def update_entry(
request: Request,
entry_id: int,
entry: dict,
authorization: str | None = Header(None),
@@ -279,10 +288,10 @@ def update_entry(
Additional fields like ``period``, ``digits`` and ``value`` are forwarded for
specialized entry types (e.g. TOTP or key/value entries).
"""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
try:
_pm.entry_manager.modify_entry(
pm.entry_manager.modify_entry(
entry_id,
username=entry.get("username"),
url=entry.get("url"),
@@ -300,31 +309,33 @@ def update_entry(
@app.post("/api/v1/entry/{entry_id}/archive")
def archive_entry(
entry_id: int, authorization: str | None = Header(None)
request: Request, entry_id: int, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Archive an entry."""
_check_token(authorization)
assert _pm is not None
_pm.entry_manager.archive_entry(entry_id)
_check_token(request, authorization)
pm = _get_pm(request)
pm.entry_manager.archive_entry(entry_id)
return {"status": "archived"}
@app.post("/api/v1/entry/{entry_id}/unarchive")
def unarchive_entry(
entry_id: int, authorization: str | None = Header(None)
request: Request, entry_id: int, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Restore an archived entry."""
_check_token(authorization)
assert _pm is not None
_pm.entry_manager.restore_entry(entry_id)
_check_token(request, authorization)
pm = _get_pm(request)
pm.entry_manager.restore_entry(entry_id)
return {"status": "active"}
@app.get("/api/v1/config/{key}")
def get_config(key: str, authorization: str | None = Header(None)) -> Any:
_check_token(authorization)
assert _pm is not None
value = _pm.config_manager.load_config(require_pin=False).get(key)
def get_config(
request: Request, key: str, authorization: str | None = Header(None)
) -> Any:
_check_token(request, authorization)
pm = _get_pm(request)
value = pm.config_manager.load_config(require_pin=False).get(key)
if value is None:
raise HTTPException(status_code=404, detail="Not found")
@@ -333,12 +344,15 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any:
@app.put("/api/v1/config/{key}")
def update_config(
key: str, data: dict, authorization: str | None = Header(None)
request: Request,
key: str,
data: dict,
authorization: str | None = Header(None),
) -> dict[str, str]:
"""Update a configuration setting."""
_check_token(authorization)
assert _pm is not None
cfg = _pm.config_manager
_check_token(request, authorization)
pm = _get_pm(request)
cfg = pm.config_manager
mapping = {
"relays": lambda v: cfg.set_relays(v, require_pin=False),
"pin": cfg.set_pin,
@@ -364,96 +378,102 @@ def update_config(
@app.post("/api/v1/secret-mode")
def set_secret_mode(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Enable/disable secret mode and set the clipboard delay."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
enabled = data.get("enabled")
delay = data.get("delay")
if enabled is None or delay is None:
raise HTTPException(status_code=400, detail="Missing fields")
cfg = _pm.config_manager
cfg = pm.config_manager
cfg.set_secret_mode_enabled(bool(enabled))
cfg.set_clipboard_clear_delay(int(delay))
_pm.secret_mode_enabled = bool(enabled)
_pm.clipboard_clear_delay = int(delay)
pm.secret_mode_enabled = bool(enabled)
pm.clipboard_clear_delay = int(delay)
return {"status": "ok"}
@app.get("/api/v1/fingerprint")
def list_fingerprints(authorization: str | None = Header(None)) -> List[str]:
_check_token(authorization)
assert _pm is not None
return _pm.fingerprint_manager.list_fingerprints()
def list_fingerprints(
request: Request, authorization: str | None = Header(None)
) -> List[str]:
_check_token(request, authorization)
pm = _get_pm(request)
return pm.fingerprint_manager.list_fingerprints()
@app.post("/api/v1/fingerprint")
def add_fingerprint(authorization: str | None = Header(None)) -> dict[str, str]:
def add_fingerprint(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Create a new seed profile."""
_check_token(authorization)
assert _pm is not None
_pm.add_new_fingerprint()
_check_token(request, authorization)
pm = _get_pm(request)
pm.add_new_fingerprint()
return {"status": "ok"}
@app.delete("/api/v1/fingerprint/{fingerprint}")
def remove_fingerprint(
fingerprint: str, authorization: str | None = Header(None)
request: Request, fingerprint: str, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Remove a seed profile."""
_check_token(authorization)
assert _pm is not None
_pm.fingerprint_manager.remove_fingerprint(fingerprint)
_check_token(request, authorization)
pm = _get_pm(request)
pm.fingerprint_manager.remove_fingerprint(fingerprint)
return {"status": "deleted"}
@app.post("/api/v1/fingerprint/select")
def select_fingerprint(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Switch the active seed profile."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
fp = data.get("fingerprint")
if not fp:
raise HTTPException(status_code=400, detail="Missing fingerprint")
_pm.select_fingerprint(fp)
pm.select_fingerprint(fp)
return {"status": "ok"}
@app.get("/api/v1/totp/export")
def export_totp(
request: Request,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict:
"""Return all stored TOTP entries in JSON format."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
return _pm.entry_manager.export_totp_entries(_pm.parent_seed)
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
return pm.entry_manager.export_totp_entries(pm.parent_seed)
@app.get("/api/v1/totp")
def get_totp_codes(
request: Request,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict:
"""Return active TOTP codes with remaining seconds."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
entries = _pm.entry_manager.list_entries(
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
entries = pm.entry_manager.list_entries(
filter_kind=EntryType.TOTP.value, include_archived=False
)
codes = []
for idx, label, _u, _url, _arch in entries:
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
code = pm.entry_manager.get_totp_code(idx, pm.parent_seed)
rem = _pm.entry_manager.get_totp_time_remaining(idx)
rem = pm.entry_manager.get_totp_time_remaining(idx)
codes.append(
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
@@ -462,23 +482,26 @@ def get_totp_codes(
@app.get("/api/v1/stats")
def get_profile_stats(authorization: str | None = Header(None)) -> dict:
def get_profile_stats(
request: Request, authorization: str | None = Header(None)
) -> dict:
"""Return statistics about the active seed profile."""
_check_token(authorization)
assert _pm is not None
return _pm.get_profile_stats()
_check_token(request, authorization)
pm = _get_pm(request)
return pm.get_profile_stats()
@app.get("/api/v1/notifications")
def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
def get_notifications(
request: Request, authorization: str | None = Header(None)
) -> List[dict]:
"""Return and clear queued notifications."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
notes = []
while True:
try:
note = _pm.notifications.get_nowait()
note = pm.notifications.get_nowait()
except queue.Empty:
break
notes.append({"level": note.level, "message": note.message})
@@ -486,47 +509,51 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
@app.get("/api/v1/nostr/pubkey")
def get_nostr_pubkey(authorization: str | None = Header(None)) -> Any:
_check_token(authorization)
assert _pm is not None
return {"npub": _pm.nostr_client.key_manager.get_npub()}
def get_nostr_pubkey(request: Request, authorization: str | None = Header(None)) -> Any:
_check_token(request, authorization)
pm = _get_pm(request)
return {"npub": pm.nostr_client.key_manager.get_npub()}
@app.get("/api/v1/relays")
def list_relays(authorization: str | None = Header(None)) -> dict:
def list_relays(request: Request, authorization: str | None = Header(None)) -> dict:
"""Return the configured Nostr relays."""
_check_token(authorization)
assert _pm is not None
cfg = _pm.config_manager.load_config(require_pin=False)
_check_token(request, authorization)
pm = _get_pm(request)
cfg = pm.config_manager.load_config(require_pin=False)
return {"relays": cfg.get("relays", [])}
@app.post("/api/v1/relays")
def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str, str]:
def add_relay(
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Add a relay URL to the configuration."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
url = data.get("url")
if not url:
raise HTTPException(status_code=400, detail="Missing url")
cfg = _pm.config_manager.load_config(require_pin=False)
cfg = pm.config_manager.load_config(require_pin=False)
relays = cfg.get("relays", [])
if url in relays:
raise HTTPException(status_code=400, detail="Relay already present")
relays.append(url)
_pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(relays)
pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(request, relays)
return {"status": "ok"}
@app.delete("/api/v1/relays/{idx}")
def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str, str]:
def remove_relay(
request: Request, idx: int, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Remove a relay by its index (1-based)."""
_check_token(authorization)
assert _pm is not None
cfg = _pm.config_manager.load_config(require_pin=False)
_check_token(request, authorization)
pm = _get_pm(request)
cfg = pm.config_manager.load_config(require_pin=False)
relays = cfg.get("relays", [])
if not (1 <= idx <= len(relays)):
@@ -534,52 +561,59 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str
if len(relays) == 1:
raise HTTPException(status_code=400, detail="At least one relay required")
relays.pop(idx - 1)
_pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(relays)
pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(request, relays)
return {"status": "ok"}
@app.post("/api/v1/relays/reset")
def reset_relays(authorization: str | None = Header(None)) -> dict[str, str]:
def reset_relays(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Reset relay list to defaults."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
from nostr.client import DEFAULT_RELAYS
relays = list(DEFAULT_RELAYS)
_pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(relays)
pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(request, relays)
return {"status": "ok"}
@app.post("/api/v1/checksum/verify")
def verify_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
def verify_checksum(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Verify the SeedPass script checksum."""
_check_token(authorization)
assert _pm is not None
_pm.handle_verify_checksum()
_check_token(request, authorization)
pm = _get_pm(request)
pm.handle_verify_checksum()
return {"status": "ok"}
@app.post("/api/v1/checksum/update")
def update_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
def update_checksum(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Regenerate the script checksum file."""
_check_token(authorization)
assert _pm is not None
_pm.handle_update_script_checksum()
_check_token(request, authorization)
pm = _get_pm(request)
pm.handle_update_script_checksum()
return {"status": "ok"}
@app.post("/api/v1/vault/export")
def export_vault(
request: Request,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
):
"""Export the vault and return the encrypted file."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
path = _pm.handle_export_database()
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
path = pm.handle_export_database()
if path is None:
raise HTTPException(status_code=500, detail="Export failed")
data = Path(path).read_bytes()
@@ -591,8 +625,8 @@ async def import_vault(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Import a vault backup from a file upload or a server path."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
ctype = request.headers.get("content-type", "")
@@ -607,7 +641,7 @@ async def import_vault(
tmp.write(data)
tmp_path = Path(tmp.name)
try:
_pm.handle_import_database(tmp_path)
pm.handle_import_database(tmp_path)
finally:
os.unlink(tmp_path)
else:
@@ -617,28 +651,29 @@ async def import_vault(
if not path_str:
raise HTTPException(status_code=400, detail="Missing file or path")
path = _validate_encryption_path(Path(path_str))
path = _validate_encryption_path(request, Path(path_str))
if not str(path).endswith(".json.enc"):
raise HTTPException(
status_code=400,
detail="Selected file must be a '.json.enc' backup",
)
_pm.handle_import_database(path)
_pm.sync_vault()
pm.handle_import_database(path)
pm.sync_vault()
return {"status": "ok"}
@app.post("/api/v1/vault/backup-parent-seed")
def backup_parent_seed(
request: Request,
data: dict,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict[str, str]:
"""Create an encrypted backup of the parent seed after confirmation."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
if not data.get("confirm"):
@@ -649,30 +684,30 @@ def backup_parent_seed(
if not path_str:
raise HTTPException(status_code=400, detail="Missing path")
path = Path(path_str)
_validate_encryption_path(path)
_pm.encryption_manager.encrypt_and_save_file(_pm.parent_seed.encode("utf-8"), path)
_validate_encryption_path(request, path)
pm.encryption_manager.encrypt_and_save_file(pm.parent_seed.encode("utf-8"), path)
return {"status": "saved", "path": str(path)}
@app.post("/api/v1/change-password")
def change_password(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Change the master password for the active profile."""
_check_token(authorization)
assert _pm is not None
_pm.change_password(data.get("old", ""), data.get("new", ""))
_check_token(request, authorization)
pm = _get_pm(request)
pm.change_password(data.get("old", ""), data.get("new", ""))
return {"status": "ok"}
@app.post("/api/v1/password")
def generate_password(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Generate a password using optional policy overrides."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
length = int(data.get("length", 12))
policy_keys = [
@@ -687,23 +722,27 @@ def generate_password(
]
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
util = UtilityService(_pm)
util = UtilityService(pm)
password = util.generate_password(length, **kwargs)
return {"password": password}
@app.post("/api/v1/vault/lock")
def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]:
def lock_vault(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Lock the vault and clear sensitive data from memory."""
_check_token(authorization)
assert _pm is not None
_pm.lock_vault()
_check_token(request, authorization)
pm = _get_pm(request)
pm.lock_vault()
return {"status": "locked"}
@app.post("/api/v1/shutdown")
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
_check_token(authorization)
async def shutdown_server(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
_check_token(request, authorization)
asyncio.get_event_loop().call_soon(sys.exit, 0)
return {"status": "shutting down"}

View File

@@ -25,7 +25,7 @@ def api_stop(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> N
try:
requests.post(
f"http://{host}:{port}/api/v1/shutdown",
headers={"Authorization": f"Bearer {api_module._token}"},
headers={"Authorization": f"Bearer {api_module.app.state.token_hash}"},
timeout=2,
)
except Exception as exc: # pragma: no cover - best effort

View File

@@ -51,8 +51,8 @@ def client(monkeypatch):
def test_token_hashed(client):
_, token = client
assert api._token != token
assert api._token == hashlib.sha256(token.encode()).hexdigest()
assert api.app.state.token_hash != token
assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest()
def test_cors_and_auth(client):
@@ -158,7 +158,7 @@ def test_update_config(client):
def set_timeout(val):
called["val"] = val
api._pm.config_manager.set_inactivity_timeout = set_timeout
api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.put(
"/api/v1/config/inactivity_timeout",
@@ -174,8 +174,9 @@ def test_update_config(client):
def test_update_config_quick_unlock(client):
cl, token = client
called = {}
api._pm.config_manager.set_quick_unlock = lambda v: called.setdefault("val", v)
api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault(
"val", v
)
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.put(
"/api/v1/config/quick_unlock",
@@ -190,8 +191,7 @@ def test_update_config_quick_unlock(client):
def test_change_password_route(client):
cl, token = client
called = {}
api._pm.change_password = lambda o, n: called.setdefault("called", (o, n))
api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n))
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
res = cl.post(
"/api/v1/change-password",

View File

@@ -3,7 +3,6 @@ from pathlib import Path
import os
import base64
import pytest
from types import SimpleNamespace
from seedpass import api
from test_api import client
@@ -25,10 +24,10 @@ def test_create_and_modify_totp_entry(client):
def modify(idx, **kwargs):
calls["modify"] = (idx, kwargs)
api._pm.entry_manager.add_totp = add_totp
api._pm.entry_manager.modify_entry = modify
api._pm.entry_manager.get_next_index = lambda: 5
api._pm.parent_seed = "seed"
api.app.state.pm.entry_manager.add_totp = add_totp
api.app.state.pm.entry_manager.modify_entry = modify
api.app.state.pm.entry_manager.get_next_index = lambda: 5
api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"}
res = cl.post(
@@ -77,9 +76,9 @@ def test_create_and_modify_ssh_entry(client):
def modify(idx, **kwargs):
calls["modify"] = (idx, kwargs)
api._pm.entry_manager.add_ssh_key = add_ssh
api._pm.entry_manager.modify_entry = modify
api._pm.parent_seed = "seed"
api.app.state.pm.entry_manager.add_ssh_key = add_ssh
api.app.state.pm.entry_manager.modify_entry = modify
api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"}
res = cl.post(
@@ -107,7 +106,7 @@ def test_update_entry_error(client):
def modify(*a, **k):
raise ValueError("nope")
api._pm.entry_manager.modify_entry = modify
api.app.state.pm.entry_manager.modify_entry = modify
headers = {"Authorization": f"Bearer {token}"}
res = cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers)
assert res.status_code == 400
@@ -121,7 +120,7 @@ def test_update_config_secret_mode(client):
def set_secret(val):
called["val"] = val
api._pm.config_manager.set_secret_mode_enabled = set_secret
api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret
headers = {"Authorization": f"Bearer {token}"}
res = cl.put(
"/api/v1/config/secret_mode_enabled",
@@ -135,8 +134,8 @@ def test_update_config_secret_mode(client):
def test_totp_export_endpoint(client):
cl, token = client
api._pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]}
api._pm.parent_seed = "seed"
api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]}
api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
res = cl.get("/api/v1/totp/export", headers=headers)
assert res.status_code == 200
@@ -145,10 +144,12 @@ def test_totp_export_endpoint(client):
def test_totp_codes_endpoint(client):
cl, token = client
api._pm.entry_manager.list_entries = lambda **kw: [(0, "Email", None, None, False)]
api._pm.entry_manager.get_totp_code = lambda i, s: "123456"
api._pm.entry_manager.get_totp_time_remaining = lambda i: 30
api._pm.parent_seed = "seed"
api.app.state.pm.entry_manager.list_entries = lambda **kw: [
(0, "Email", None, None, False)
]
api.app.state.pm.entry_manager.get_totp_code = lambda i, s: "123456"
api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30
api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
res = cl.get("/api/v1/totp", headers=headers)
assert res.status_code == 200
@@ -169,11 +170,11 @@ def test_fingerprint_endpoints(client):
cl, token = client
calls = {}
api._pm.add_new_fingerprint = lambda: calls.setdefault("add", True)
api._pm.fingerprint_manager.remove_fingerprint = lambda fp: calls.setdefault(
"remove", fp
api.app.state.pm.add_new_fingerprint = lambda: calls.setdefault("add", True)
api.app.state.pm.fingerprint_manager.remove_fingerprint = (
lambda fp: calls.setdefault("remove", fp)
)
api._pm.select_fingerprint = lambda fp: calls.setdefault("select", fp)
api.app.state.pm.select_fingerprint = lambda fp: calls.setdefault("select", fp)
headers = {"Authorization": f"Bearer {token}"}
@@ -201,8 +202,10 @@ def test_checksum_endpoints(client):
cl, token = client
calls = {}
api._pm.handle_verify_checksum = lambda: calls.setdefault("verify", True)
api._pm.handle_update_script_checksum = lambda: calls.setdefault("update", True)
api.app.state.pm.handle_verify_checksum = lambda: calls.setdefault("verify", True)
api.app.state.pm.handle_update_script_checksum = lambda: calls.setdefault(
"update", True
)
headers = {"Authorization": f"Bearer {token}"}
@@ -224,9 +227,11 @@ def test_vault_import_via_path(client, tmp_path):
def import_db(path):
called["path"] = path
api._pm.handle_import_database = import_db
api._pm.sync_vault = lambda: called.setdefault("sync", True)
api._pm.encryption_manager = SimpleNamespace(resolve_relative_path=lambda p: p)
api.app.state.pm.handle_import_database = import_db
api.app.state.pm.sync_vault = lambda: called.setdefault("sync", True)
api.app.state.pm.encryption_manager = SimpleNamespace(
resolve_relative_path=lambda p: p
)
file_path = tmp_path / "b.json.enc"
file_path.write_text("{}")
@@ -249,8 +254,8 @@ def test_vault_import_via_upload(client, tmp_path):
def import_db(path):
called["path"] = path
api._pm.handle_import_database = import_db
api._pm.sync_vault = lambda: called.setdefault("sync", True)
api.app.state.pm.handle_import_database = import_db
api.app.state.pm.sync_vault = lambda: called.setdefault("sync", True)
file_path = tmp_path / "c.json"
file_path.write_text("{}")
@@ -269,9 +274,11 @@ def test_vault_import_via_upload(client, tmp_path):
def test_vault_import_invalid_extension(client):
cl, token = client
api._pm.handle_import_database = lambda path: None
api._pm.sync_vault = lambda: None
api._pm.encryption_manager = SimpleNamespace(resolve_relative_path=lambda p: p)
api.app.state.pm.handle_import_database = lambda path: None
api.app.state.pm.sync_vault = lambda: None
api.app.state.pm.encryption_manager = SimpleNamespace(
resolve_relative_path=lambda p: p
)
headers = {"Authorization": f"Bearer {token}"}
res = cl.post(
@@ -285,9 +292,9 @@ def test_vault_import_invalid_extension(client):
def test_vault_import_path_traversal_blocked(client, tmp_path):
cl, token = client
key = base64.urlsafe_b64encode(os.urandom(32))
api._pm.encryption_manager = EncryptionManager(key, tmp_path)
api._pm.handle_import_database = lambda path: None
api._pm.sync_vault = lambda: None
api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path)
api.app.state.pm.handle_import_database = lambda path: None
api.app.state.pm.sync_vault = lambda: None
headers = {"Authorization": f"Bearer {token}"}
res = cl.post(
@@ -304,20 +311,22 @@ def test_vault_lock_endpoint(client):
def lock():
called["locked"] = True
api._pm.locked = True
api.app.state.pm.locked = True
api._pm.lock_vault = lock
api._pm.locked = False
api.app.state.pm.lock_vault = lock
api.app.state.pm.locked = False
headers = {"Authorization": f"Bearer {token}"}
res = cl.post("/api/v1/vault/lock", headers=headers)
assert res.status_code == 200
assert res.json() == {"status": "locked"}
assert called.get("locked") is True
assert api._pm.locked is True
api._pm.unlock_vault = lambda pw: setattr(api._pm, "locked", False)
api._pm.unlock_vault("pw")
assert api._pm.locked is False
assert api.app.state.pm.locked is True
api.app.state.pm.unlock_vault = lambda pw: setattr(
api.app.state.pm, "locked", False
)
api.app.state.pm.unlock_vault("pw")
assert api.app.state.pm.locked is False
def test_secret_mode_endpoint(client):
@@ -330,8 +339,8 @@ def test_secret_mode_endpoint(client):
def set_delay(val):
called.setdefault("delay", val)
api._pm.config_manager.set_secret_mode_enabled = set_secret
api._pm.config_manager.set_clipboard_clear_delay = set_delay
api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret
api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay
headers = {"Authorization": f"Bearer {token}"}
res = cl.post(
@@ -350,7 +359,7 @@ def test_vault_export_endpoint(client, tmp_path):
out = tmp_path / "out.json"
out.write_text("data")
api._pm.handle_export_database = lambda: out
api.app.state.pm.handle_export_database = lambda: out
headers = {
"Authorization": f"Bearer {token}",
@@ -366,9 +375,9 @@ def test_vault_export_endpoint(client, tmp_path):
def test_backup_parent_seed_endpoint(client, tmp_path):
cl, token = client
api._pm.parent_seed = "seed"
api.app.state.pm.parent_seed = "seed"
called = {}
api._pm.encryption_manager = SimpleNamespace(
api.app.state.pm.encryption_manager = SimpleNamespace(
encrypt_and_save_file=lambda data, path: called.setdefault("path", path),
resolve_relative_path=lambda p: p,
)
@@ -396,9 +405,9 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
cl, token = client
api._pm.parent_seed = "seed"
api.app.state.pm.parent_seed = "seed"
key = base64.urlsafe_b64encode(os.urandom(32))
api._pm.encryption_manager = EncryptionManager(key, tmp_path)
api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path)
headers = {
"Authorization": f"Bearer {token}",
"X-SeedPass-Password": "pw",
@@ -424,8 +433,8 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
def set_relays(new, require_pin=False):
called["set"] = new
api._pm.config_manager.load_config = load_config
api._pm.config_manager.set_relays = set_relays
api.app.state.pm.config_manager.load_config = load_config
api.app.state.pm.config_manager.set_relays = set_relays
monkeypatch.setattr(
NostrClient,
"initialize_client_pool",
@@ -434,8 +443,8 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
monkeypatch.setattr(
nostr_client, "close_client_pool", lambda: called.setdefault("close", True)
)
api._pm.nostr_client = nostr_client
api._pm.nostr_client.relays = relays.copy()
api.app.state.pm.nostr_client = nostr_client
api.app.state.pm.nostr_client.relays = relays.copy()
headers = {"Authorization": f"Bearer {token}"}
@@ -447,7 +456,7 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
assert res.status_code == 200
assert called["set"] == ["wss://a", "wss://b", "wss://c"]
api._pm.config_manager.load_config = lambda require_pin=False: {
api.app.state.pm.config_manager.load_config = lambda require_pin=False: {
"relays": ["wss://a", "wss://b", "wss://c"]
}
res = cl.delete("/api/v1/relays/2", headers=headers)
@@ -457,7 +466,7 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
res = cl.post("/api/v1/relays/reset", headers=headers)
assert res.status_code == 200
assert called.get("init") is True
assert api._pm.nostr_client.relays == list(DEFAULT_RELAYS)
assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS)
def test_generate_password_no_special_chars(client):
@@ -471,8 +480,10 @@ def test_generate_password_no_special_chars(client):
def derive_entropy(self, index: int, bytes_len: int, app_no: int = 32) -> bytes:
return bytes(range(bytes_len))
api._pm.password_generator = PasswordGenerator(DummyEnc(), "seed", DummyBIP85())
api._pm.parent_seed = "seed"
api.app.state.pm.password_generator = PasswordGenerator(
DummyEnc(), "seed", DummyBIP85()
)
api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"}
res = cl.post(
@@ -496,8 +507,10 @@ def test_generate_password_allowed_chars(client):
def derive_entropy(self, index: int, bytes_len: int, app_no: int = 32) -> bytes:
return bytes((index + i) % 256 for i in range(bytes_len))
api._pm.password_generator = PasswordGenerator(DummyEnc(), "seed", DummyBIP85())
api._pm.parent_seed = "seed"
api.app.state.pm.password_generator = PasswordGenerator(
DummyEnc(), "seed", DummyBIP85()
)
api.app.state.pm.parent_seed = "seed"
headers = {"Authorization": f"Bearer {token}"}
allowed = "@$"

View File

@@ -6,40 +6,42 @@ import seedpass.api as api
def test_notifications_endpoint(client):
cl, token = client
api._pm.notifications = queue.Queue()
api._pm.notifications.put(SimpleNamespace(message="m1", level="INFO"))
api._pm.notifications.put(SimpleNamespace(message="m2", level="WARNING"))
api.app.state.pm.notifications = queue.Queue()
api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO"))
api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING"))
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 200
assert res.json() == [
{"level": "INFO", "message": "m1"},
{"level": "WARNING", "message": "m2"},
]
assert api._pm.notifications.empty()
assert api.app.state.pm.notifications.empty()
def test_notifications_endpoint_clears_queue(client):
cl, token = client
api._pm.notifications = queue.Queue()
api._pm.notifications.put(SimpleNamespace(message="hi", level="INFO"))
api.app.state.pm.notifications = queue.Queue()
api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO"))
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 200
assert res.json() == [{"level": "INFO", "message": "hi"}]
assert api._pm.notifications.empty()
assert api.app.state.pm.notifications.empty()
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
assert res.json() == []
def test_notifications_endpoint_does_not_clear_current(client):
cl, token = client
api._pm.notifications = queue.Queue()
api.app.state.pm.notifications = queue.Queue()
msg = SimpleNamespace(message="keep", level="INFO")
api._pm.notifications.put(msg)
api._pm._current_notification = msg
api._pm.get_current_notification = lambda: api._pm._current_notification
api.app.state.pm.notifications.put(msg)
api.app.state.pm._current_notification = msg
api.app.state.pm.get_current_notification = (
lambda: api.app.state.pm._current_notification
)
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 200
assert res.json() == [{"level": "INFO", "message": "keep"}]
assert api._pm.notifications.empty()
assert api._pm.get_current_notification() is msg
assert api.app.state.pm.notifications.empty()
assert api.app.state.pm.get_current_notification() is msg

View File

@@ -7,7 +7,7 @@ def test_profile_stats_endpoint(client):
# monkeypatch set _pm.get_profile_stats after client fixture started
import seedpass.api as api
api._pm.get_profile_stats = lambda: stats
api.app.state.pm.get_profile_stats = lambda: stats
res = cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 200
assert res.json() == stats