mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-08 07:18:47 +00:00
refactor: move api state to app
This commit is contained in:
@@ -36,35 +36,39 @@ _RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds"
|
||||
limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR])
|
||||
app = FastAPI()
|
||||
|
||||
_pm: Optional[PasswordManager] = None
|
||||
_token: str = ""
|
||||
_jwt_secret: str = ""
|
||||
|
||||
def _get_pm(request: Request) -> PasswordManager:
|
||||
pm = getattr(request.app.state, "pm", None)
|
||||
assert pm is not None
|
||||
return pm
|
||||
|
||||
|
||||
def _check_token(auth: str | None) -> None:
|
||||
def _check_token(request: Request, auth: str | None) -> None:
|
||||
if auth is None or not auth.startswith("Bearer "):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
token = auth.split(" ", 1)[1]
|
||||
jwt_secret = getattr(request.app.state, "jwt_secret", "")
|
||||
token_hash = getattr(request.app.state, "token_hash", "")
|
||||
try:
|
||||
jwt.decode(token, _jwt_secret, algorithms=["HS256"])
|
||||
jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(status_code=401, detail="Token expired")
|
||||
except jwt.InvalidTokenError:
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), _token):
|
||||
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), token_hash):
|
||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||
|
||||
|
||||
def _reload_relays(relays: list[str]) -> None:
|
||||
def _reload_relays(request: Request, relays: list[str]) -> None:
|
||||
"""Reload the Nostr client with a new relay list."""
|
||||
assert _pm is not None
|
||||
pm = _get_pm(request)
|
||||
try:
|
||||
_pm.nostr_client.close_client_pool()
|
||||
pm.nostr_client.close_client_pool()
|
||||
except Exception:
|
||||
pass
|
||||
try:
|
||||
_pm.nostr_client.relays = relays
|
||||
_pm.nostr_client.initialize_client_pool()
|
||||
pm.nostr_client.relays = relays
|
||||
pm.nostr_client.initialize_client_pool()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -77,15 +81,15 @@ def start_server(fingerprint: str | None = None) -> str:
|
||||
fingerprint:
|
||||
Optional seed profile fingerprint to select before starting the server.
|
||||
"""
|
||||
global _pm, _token, _jwt_secret
|
||||
if fingerprint is None:
|
||||
_pm = PasswordManager()
|
||||
pm = PasswordManager()
|
||||
else:
|
||||
_pm = PasswordManager(fingerprint=fingerprint)
|
||||
_jwt_secret = secrets.token_urlsafe(32)
|
||||
pm = PasswordManager(fingerprint=fingerprint)
|
||||
app.state.pm = pm
|
||||
app.state.jwt_secret = secrets.token_urlsafe(32)
|
||||
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
||||
raw_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
|
||||
_token = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
raw_token = jwt.encode(payload, app.state.jwt_secret, algorithm="HS256")
|
||||
app.state.token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||
if not getattr(app.state, "limiter", None):
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
@@ -105,30 +109,32 @@ def start_server(fingerprint: str | None = None) -> str:
|
||||
return raw_token
|
||||
|
||||
|
||||
def _require_password(password: str | None) -> None:
|
||||
assert _pm is not None
|
||||
if password is None or not _pm.verify_password(password):
|
||||
def _require_password(request: Request, password: str | None) -> None:
|
||||
pm = _get_pm(request)
|
||||
if password is None or not pm.verify_password(password):
|
||||
raise HTTPException(status_code=401, detail="Invalid password")
|
||||
|
||||
|
||||
def _validate_encryption_path(path: Path) -> Path:
|
||||
def _validate_encryption_path(request: Request, path: Path) -> Path:
|
||||
"""Validate and normalize ``path`` within the active fingerprint directory.
|
||||
|
||||
Returns the resolved absolute path if validation succeeds.
|
||||
"""
|
||||
|
||||
assert _pm is not None
|
||||
pm = _get_pm(request)
|
||||
try:
|
||||
return _pm.encryption_manager.resolve_relative_path(path)
|
||||
return pm.encryption_manager.resolve_relative_path(path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
|
||||
@app.get("/api/v1/entry")
|
||||
def search_entry(query: str, authorization: str | None = Header(None)) -> List[Any]:
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
results = _pm.entry_manager.search_entries(query)
|
||||
def search_entry(
|
||||
request: Request, query: str, authorization: str | None = Header(None)
|
||||
) -> List[Any]:
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
results = pm.entry_manager.search_entries(query)
|
||||
return [
|
||||
{
|
||||
"id": idx,
|
||||
@@ -144,14 +150,15 @@ def search_entry(query: str, authorization: str | None = Header(None)) -> List[A
|
||||
|
||||
@app.get("/api/v1/entry/{entry_id}")
|
||||
def get_entry(
|
||||
request: Request,
|
||||
entry_id: int,
|
||||
authorization: str | None = Header(None),
|
||||
password: str | None = Header(None, alias="X-SeedPass-Password"),
|
||||
) -> Any:
|
||||
_check_token(authorization)
|
||||
_require_password(password)
|
||||
assert _pm is not None
|
||||
entry = _pm.entry_manager.retrieve_entry(entry_id)
|
||||
_check_token(request, authorization)
|
||||
_require_password(request, password)
|
||||
pm = _get_pm(request)
|
||||
entry = pm.entry_manager.retrieve_entry(entry_id)
|
||||
if entry is None:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return entry
|
||||
@@ -159,6 +166,7 @@ def get_entry(
|
||||
|
||||
@app.post("/api/v1/entry")
|
||||
def create_entry(
|
||||
request: Request,
|
||||
entry: dict,
|
||||
authorization: str | None = Header(None),
|
||||
) -> dict[str, Any]:
|
||||
@@ -168,8 +176,8 @@ def create_entry(
|
||||
on, the corresponding entry type is created. When omitted or set to
|
||||
``password`` the behaviour matches the legacy password-entry API.
|
||||
"""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
|
||||
etype = (entry.get("type") or entry.get("kind") or "password").lower()
|
||||
|
||||
@@ -186,7 +194,7 @@ def create_entry(
|
||||
]
|
||||
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
|
||||
|
||||
index = _pm.entry_manager.add_entry(
|
||||
index = pm.entry_manager.add_entry(
|
||||
entry.get("label"),
|
||||
int(entry.get("length", 12)),
|
||||
entry.get("username"),
|
||||
@@ -196,11 +204,11 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "totp":
|
||||
index = _pm.entry_manager.get_next_index()
|
||||
index = pm.entry_manager.get_next_index()
|
||||
|
||||
uri = _pm.entry_manager.add_totp(
|
||||
uri = pm.entry_manager.add_totp(
|
||||
entry.get("label"),
|
||||
_pm.parent_seed,
|
||||
pm.parent_seed,
|
||||
secret=entry.get("secret"),
|
||||
index=entry.get("index"),
|
||||
period=int(entry.get("period", 30)),
|
||||
@@ -211,9 +219,9 @@ def create_entry(
|
||||
return {"id": index, "uri": uri}
|
||||
|
||||
if etype == "ssh":
|
||||
index = _pm.entry_manager.add_ssh_key(
|
||||
index = pm.entry_manager.add_ssh_key(
|
||||
entry.get("label"),
|
||||
_pm.parent_seed,
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
notes=entry.get("notes", ""),
|
||||
archived=entry.get("archived", False),
|
||||
@@ -221,9 +229,9 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "pgp":
|
||||
index = _pm.entry_manager.add_pgp_key(
|
||||
index = pm.entry_manager.add_pgp_key(
|
||||
entry.get("label"),
|
||||
_pm.parent_seed,
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
key_type=entry.get("key_type", "ed25519"),
|
||||
user_id=entry.get("user_id", ""),
|
||||
@@ -233,9 +241,9 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "nostr":
|
||||
index = _pm.entry_manager.add_nostr_key(
|
||||
index = pm.entry_manager.add_nostr_key(
|
||||
entry.get("label"),
|
||||
_pm.parent_seed,
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
notes=entry.get("notes", ""),
|
||||
archived=entry.get("archived", False),
|
||||
@@ -243,7 +251,7 @@ def create_entry(
|
||||
return {"id": index}
|
||||
|
||||
if etype == "key_value":
|
||||
index = _pm.entry_manager.add_key_value(
|
||||
index = pm.entry_manager.add_key_value(
|
||||
entry.get("label"),
|
||||
entry.get("key"),
|
||||
entry.get("value"),
|
||||
@@ -253,13 +261,13 @@ def create_entry(
|
||||
|
||||
if etype in {"seed", "managed_account"}:
|
||||
func = (
|
||||
_pm.entry_manager.add_seed
|
||||
pm.entry_manager.add_seed
|
||||
if etype == "seed"
|
||||
else _pm.entry_manager.add_managed_account
|
||||
else pm.entry_manager.add_managed_account
|
||||
)
|
||||
index = func(
|
||||
entry.get("label"),
|
||||
_pm.parent_seed,
|
||||
pm.parent_seed,
|
||||
index=entry.get("index"),
|
||||
notes=entry.get("notes", ""),
|
||||
)
|
||||
@@ -270,6 +278,7 @@ def create_entry(
|
||||
|
||||
@app.put("/api/v1/entry/{entry_id}")
|
||||
def update_entry(
|
||||
request: Request,
|
||||
entry_id: int,
|
||||
entry: dict,
|
||||
authorization: str | None = Header(None),
|
||||
@@ -279,10 +288,10 @@ def update_entry(
|
||||
Additional fields like ``period``, ``digits`` and ``value`` are forwarded for
|
||||
specialized entry types (e.g. TOTP or key/value entries).
|
||||
"""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
try:
|
||||
_pm.entry_manager.modify_entry(
|
||||
pm.entry_manager.modify_entry(
|
||||
entry_id,
|
||||
username=entry.get("username"),
|
||||
url=entry.get("url"),
|
||||
@@ -300,31 +309,33 @@ def update_entry(
|
||||
|
||||
@app.post("/api/v1/entry/{entry_id}/archive")
|
||||
def archive_entry(
|
||||
entry_id: int, authorization: str | None = Header(None)
|
||||
request: Request, entry_id: int, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Archive an entry."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.entry_manager.archive_entry(entry_id)
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.entry_manager.archive_entry(entry_id)
|
||||
return {"status": "archived"}
|
||||
|
||||
|
||||
@app.post("/api/v1/entry/{entry_id}/unarchive")
|
||||
def unarchive_entry(
|
||||
entry_id: int, authorization: str | None = Header(None)
|
||||
request: Request, entry_id: int, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Restore an archived entry."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.entry_manager.restore_entry(entry_id)
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.entry_manager.restore_entry(entry_id)
|
||||
return {"status": "active"}
|
||||
|
||||
|
||||
@app.get("/api/v1/config/{key}")
|
||||
def get_config(key: str, authorization: str | None = Header(None)) -> Any:
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
value = _pm.config_manager.load_config(require_pin=False).get(key)
|
||||
def get_config(
|
||||
request: Request, key: str, authorization: str | None = Header(None)
|
||||
) -> Any:
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
value = pm.config_manager.load_config(require_pin=False).get(key)
|
||||
|
||||
if value is None:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
@@ -333,12 +344,15 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any:
|
||||
|
||||
@app.put("/api/v1/config/{key}")
|
||||
def update_config(
|
||||
key: str, data: dict, authorization: str | None = Header(None)
|
||||
request: Request,
|
||||
key: str,
|
||||
data: dict,
|
||||
authorization: str | None = Header(None),
|
||||
) -> dict[str, str]:
|
||||
"""Update a configuration setting."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
cfg = _pm.config_manager
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
cfg = pm.config_manager
|
||||
mapping = {
|
||||
"relays": lambda v: cfg.set_relays(v, require_pin=False),
|
||||
"pin": cfg.set_pin,
|
||||
@@ -364,96 +378,102 @@ def update_config(
|
||||
|
||||
@app.post("/api/v1/secret-mode")
|
||||
def set_secret_mode(
|
||||
data: dict, authorization: str | None = Header(None)
|
||||
request: Request, data: dict, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Enable/disable secret mode and set the clipboard delay."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
enabled = data.get("enabled")
|
||||
|
||||
delay = data.get("delay")
|
||||
|
||||
if enabled is None or delay is None:
|
||||
raise HTTPException(status_code=400, detail="Missing fields")
|
||||
cfg = _pm.config_manager
|
||||
cfg = pm.config_manager
|
||||
cfg.set_secret_mode_enabled(bool(enabled))
|
||||
cfg.set_clipboard_clear_delay(int(delay))
|
||||
_pm.secret_mode_enabled = bool(enabled)
|
||||
_pm.clipboard_clear_delay = int(delay)
|
||||
pm.secret_mode_enabled = bool(enabled)
|
||||
pm.clipboard_clear_delay = int(delay)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/api/v1/fingerprint")
|
||||
def list_fingerprints(authorization: str | None = Header(None)) -> List[str]:
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
return _pm.fingerprint_manager.list_fingerprints()
|
||||
def list_fingerprints(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> List[str]:
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
return pm.fingerprint_manager.list_fingerprints()
|
||||
|
||||
|
||||
@app.post("/api/v1/fingerprint")
|
||||
def add_fingerprint(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def add_fingerprint(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Create a new seed profile."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.add_new_fingerprint()
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.add_new_fingerprint()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.delete("/api/v1/fingerprint/{fingerprint}")
|
||||
def remove_fingerprint(
|
||||
fingerprint: str, authorization: str | None = Header(None)
|
||||
request: Request, fingerprint: str, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Remove a seed profile."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.fingerprint_manager.remove_fingerprint(fingerprint)
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.fingerprint_manager.remove_fingerprint(fingerprint)
|
||||
return {"status": "deleted"}
|
||||
|
||||
|
||||
@app.post("/api/v1/fingerprint/select")
|
||||
def select_fingerprint(
|
||||
data: dict, authorization: str | None = Header(None)
|
||||
request: Request, data: dict, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Switch the active seed profile."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
fp = data.get("fingerprint")
|
||||
|
||||
if not fp:
|
||||
raise HTTPException(status_code=400, detail="Missing fingerprint")
|
||||
_pm.select_fingerprint(fp)
|
||||
pm.select_fingerprint(fp)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.get("/api/v1/totp/export")
|
||||
def export_totp(
|
||||
request: Request,
|
||||
authorization: str | None = Header(None),
|
||||
password: str | None = Header(None, alias="X-SeedPass-Password"),
|
||||
) -> dict:
|
||||
"""Return all stored TOTP entries in JSON format."""
|
||||
_check_token(authorization)
|
||||
_require_password(password)
|
||||
assert _pm is not None
|
||||
return _pm.entry_manager.export_totp_entries(_pm.parent_seed)
|
||||
_check_token(request, authorization)
|
||||
_require_password(request, password)
|
||||
pm = _get_pm(request)
|
||||
return pm.entry_manager.export_totp_entries(pm.parent_seed)
|
||||
|
||||
|
||||
@app.get("/api/v1/totp")
|
||||
def get_totp_codes(
|
||||
request: Request,
|
||||
authorization: str | None = Header(None),
|
||||
password: str | None = Header(None, alias="X-SeedPass-Password"),
|
||||
) -> dict:
|
||||
"""Return active TOTP codes with remaining seconds."""
|
||||
_check_token(authorization)
|
||||
_require_password(password)
|
||||
assert _pm is not None
|
||||
entries = _pm.entry_manager.list_entries(
|
||||
_check_token(request, authorization)
|
||||
_require_password(request, password)
|
||||
pm = _get_pm(request)
|
||||
entries = pm.entry_manager.list_entries(
|
||||
filter_kind=EntryType.TOTP.value, include_archived=False
|
||||
)
|
||||
codes = []
|
||||
for idx, label, _u, _url, _arch in entries:
|
||||
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
|
||||
code = pm.entry_manager.get_totp_code(idx, pm.parent_seed)
|
||||
|
||||
rem = _pm.entry_manager.get_totp_time_remaining(idx)
|
||||
rem = pm.entry_manager.get_totp_time_remaining(idx)
|
||||
|
||||
codes.append(
|
||||
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
|
||||
@@ -462,23 +482,26 @@ def get_totp_codes(
|
||||
|
||||
|
||||
@app.get("/api/v1/stats")
|
||||
def get_profile_stats(authorization: str | None = Header(None)) -> dict:
|
||||
def get_profile_stats(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict:
|
||||
"""Return statistics about the active seed profile."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
return _pm.get_profile_stats()
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
return pm.get_profile_stats()
|
||||
|
||||
|
||||
@app.get("/api/v1/notifications")
|
||||
def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
|
||||
def get_notifications(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> List[dict]:
|
||||
"""Return and clear queued notifications."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
notes = []
|
||||
while True:
|
||||
try:
|
||||
note = _pm.notifications.get_nowait()
|
||||
|
||||
note = pm.notifications.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
notes.append({"level": note.level, "message": note.message})
|
||||
@@ -486,47 +509,51 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
|
||||
|
||||
|
||||
@app.get("/api/v1/nostr/pubkey")
|
||||
def get_nostr_pubkey(authorization: str | None = Header(None)) -> Any:
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
return {"npub": _pm.nostr_client.key_manager.get_npub()}
|
||||
def get_nostr_pubkey(request: Request, authorization: str | None = Header(None)) -> Any:
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
return {"npub": pm.nostr_client.key_manager.get_npub()}
|
||||
|
||||
|
||||
@app.get("/api/v1/relays")
|
||||
def list_relays(authorization: str | None = Header(None)) -> dict:
|
||||
def list_relays(request: Request, authorization: str | None = Header(None)) -> dict:
|
||||
"""Return the configured Nostr relays."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
cfg = pm.config_manager.load_config(require_pin=False)
|
||||
return {"relays": cfg.get("relays", [])}
|
||||
|
||||
|
||||
@app.post("/api/v1/relays")
|
||||
def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def add_relay(
|
||||
request: Request, data: dict, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Add a relay URL to the configuration."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
url = data.get("url")
|
||||
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Missing url")
|
||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||
cfg = pm.config_manager.load_config(require_pin=False)
|
||||
relays = cfg.get("relays", [])
|
||||
|
||||
if url in relays:
|
||||
raise HTTPException(status_code=400, detail="Relay already present")
|
||||
relays.append(url)
|
||||
_pm.config_manager.set_relays(relays, require_pin=False)
|
||||
_reload_relays(relays)
|
||||
pm.config_manager.set_relays(relays, require_pin=False)
|
||||
_reload_relays(request, relays)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.delete("/api/v1/relays/{idx}")
|
||||
def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def remove_relay(
|
||||
request: Request, idx: int, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Remove a relay by its index (1-based)."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
cfg = pm.config_manager.load_config(require_pin=False)
|
||||
relays = cfg.get("relays", [])
|
||||
|
||||
if not (1 <= idx <= len(relays)):
|
||||
@@ -534,52 +561,59 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str
|
||||
if len(relays) == 1:
|
||||
raise HTTPException(status_code=400, detail="At least one relay required")
|
||||
relays.pop(idx - 1)
|
||||
_pm.config_manager.set_relays(relays, require_pin=False)
|
||||
_reload_relays(relays)
|
||||
pm.config_manager.set_relays(relays, require_pin=False)
|
||||
_reload_relays(request, relays)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/v1/relays/reset")
|
||||
def reset_relays(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def reset_relays(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Reset relay list to defaults."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
from nostr.client import DEFAULT_RELAYS
|
||||
|
||||
relays = list(DEFAULT_RELAYS)
|
||||
_pm.config_manager.set_relays(relays, require_pin=False)
|
||||
_reload_relays(relays)
|
||||
pm.config_manager.set_relays(relays, require_pin=False)
|
||||
_reload_relays(request, relays)
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/v1/checksum/verify")
|
||||
def verify_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def verify_checksum(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Verify the SeedPass script checksum."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.handle_verify_checksum()
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.handle_verify_checksum()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/v1/checksum/update")
|
||||
def update_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def update_checksum(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Regenerate the script checksum file."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.handle_update_script_checksum()
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.handle_update_script_checksum()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/v1/vault/export")
|
||||
def export_vault(
|
||||
request: Request,
|
||||
authorization: str | None = Header(None),
|
||||
password: str | None = Header(None, alias="X-SeedPass-Password"),
|
||||
):
|
||||
"""Export the vault and return the encrypted file."""
|
||||
_check_token(authorization)
|
||||
_require_password(password)
|
||||
assert _pm is not None
|
||||
path = _pm.handle_export_database()
|
||||
_check_token(request, authorization)
|
||||
_require_password(request, password)
|
||||
pm = _get_pm(request)
|
||||
path = pm.handle_export_database()
|
||||
if path is None:
|
||||
raise HTTPException(status_code=500, detail="Export failed")
|
||||
data = Path(path).read_bytes()
|
||||
@@ -591,8 +625,8 @@ async def import_vault(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Import a vault backup from a file upload or a server path."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
|
||||
ctype = request.headers.get("content-type", "")
|
||||
|
||||
@@ -607,7 +641,7 @@ async def import_vault(
|
||||
tmp.write(data)
|
||||
tmp_path = Path(tmp.name)
|
||||
try:
|
||||
_pm.handle_import_database(tmp_path)
|
||||
pm.handle_import_database(tmp_path)
|
||||
finally:
|
||||
os.unlink(tmp_path)
|
||||
else:
|
||||
@@ -617,28 +651,29 @@ async def import_vault(
|
||||
if not path_str:
|
||||
raise HTTPException(status_code=400, detail="Missing file or path")
|
||||
|
||||
path = _validate_encryption_path(Path(path_str))
|
||||
path = _validate_encryption_path(request, Path(path_str))
|
||||
if not str(path).endswith(".json.enc"):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Selected file must be a '.json.enc' backup",
|
||||
)
|
||||
|
||||
_pm.handle_import_database(path)
|
||||
_pm.sync_vault()
|
||||
pm.handle_import_database(path)
|
||||
pm.sync_vault()
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/v1/vault/backup-parent-seed")
|
||||
def backup_parent_seed(
|
||||
request: Request,
|
||||
data: dict,
|
||||
authorization: str | None = Header(None),
|
||||
password: str | None = Header(None, alias="X-SeedPass-Password"),
|
||||
) -> dict[str, str]:
|
||||
"""Create an encrypted backup of the parent seed after confirmation."""
|
||||
_check_token(authorization)
|
||||
_require_password(password)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
_require_password(request, password)
|
||||
pm = _get_pm(request)
|
||||
|
||||
if not data.get("confirm"):
|
||||
|
||||
@@ -649,30 +684,30 @@ def backup_parent_seed(
|
||||
if not path_str:
|
||||
raise HTTPException(status_code=400, detail="Missing path")
|
||||
path = Path(path_str)
|
||||
_validate_encryption_path(path)
|
||||
_pm.encryption_manager.encrypt_and_save_file(_pm.parent_seed.encode("utf-8"), path)
|
||||
_validate_encryption_path(request, path)
|
||||
pm.encryption_manager.encrypt_and_save_file(pm.parent_seed.encode("utf-8"), path)
|
||||
return {"status": "saved", "path": str(path)}
|
||||
|
||||
|
||||
@app.post("/api/v1/change-password")
|
||||
def change_password(
|
||||
data: dict, authorization: str | None = Header(None)
|
||||
request: Request, data: dict, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Change the master password for the active profile."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.change_password(data.get("old", ""), data.get("new", ""))
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.change_password(data.get("old", ""), data.get("new", ""))
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@app.post("/api/v1/password")
|
||||
def generate_password(
|
||||
data: dict, authorization: str | None = Header(None)
|
||||
request: Request, data: dict, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Generate a password using optional policy overrides."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
length = int(data.get("length", 12))
|
||||
|
||||
policy_keys = [
|
||||
@@ -687,23 +722,27 @@ def generate_password(
|
||||
]
|
||||
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
|
||||
|
||||
util = UtilityService(_pm)
|
||||
util = UtilityService(pm)
|
||||
password = util.generate_password(length, **kwargs)
|
||||
return {"password": password}
|
||||
|
||||
|
||||
@app.post("/api/v1/vault/lock")
|
||||
def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
def lock_vault(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
"""Lock the vault and clear sensitive data from memory."""
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.lock_vault()
|
||||
_check_token(request, authorization)
|
||||
pm = _get_pm(request)
|
||||
pm.lock_vault()
|
||||
return {"status": "locked"}
|
||||
|
||||
|
||||
@app.post("/api/v1/shutdown")
|
||||
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
_check_token(authorization)
|
||||
async def shutdown_server(
|
||||
request: Request, authorization: str | None = Header(None)
|
||||
) -> dict[str, str]:
|
||||
_check_token(request, authorization)
|
||||
asyncio.get_event_loop().call_soon(sys.exit, 0)
|
||||
|
||||
return {"status": "shutting down"}
|
||||
|
@@ -25,7 +25,7 @@ def api_stop(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> N
|
||||
try:
|
||||
requests.post(
|
||||
f"http://{host}:{port}/api/v1/shutdown",
|
||||
headers={"Authorization": f"Bearer {api_module._token}"},
|
||||
headers={"Authorization": f"Bearer {api_module.app.state.token_hash}"},
|
||||
timeout=2,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - best effort
|
||||
|
@@ -51,8 +51,8 @@ def client(monkeypatch):
|
||||
|
||||
def test_token_hashed(client):
|
||||
_, token = client
|
||||
assert api._token != token
|
||||
assert api._token == hashlib.sha256(token.encode()).hexdigest()
|
||||
assert api.app.state.token_hash != token
|
||||
assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest()
|
||||
|
||||
|
||||
def test_cors_and_auth(client):
|
||||
@@ -158,7 +158,7 @@ def test_update_config(client):
|
||||
def set_timeout(val):
|
||||
called["val"] = val
|
||||
|
||||
api._pm.config_manager.set_inactivity_timeout = set_timeout
|
||||
api.app.state.pm.config_manager.set_inactivity_timeout = set_timeout
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.put(
|
||||
"/api/v1/config/inactivity_timeout",
|
||||
@@ -174,8 +174,9 @@ def test_update_config(client):
|
||||
def test_update_config_quick_unlock(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
api._pm.config_manager.set_quick_unlock = lambda v: called.setdefault("val", v)
|
||||
api.app.state.pm.config_manager.set_quick_unlock = lambda v: called.setdefault(
|
||||
"val", v
|
||||
)
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.put(
|
||||
"/api/v1/config/quick_unlock",
|
||||
@@ -190,8 +191,7 @@ def test_update_config_quick_unlock(client):
|
||||
def test_change_password_route(client):
|
||||
cl, token = client
|
||||
called = {}
|
||||
|
||||
api._pm.change_password = lambda o, n: called.setdefault("called", (o, n))
|
||||
api.app.state.pm.change_password = lambda o, n: called.setdefault("called", (o, n))
|
||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||
res = cl.post(
|
||||
"/api/v1/change-password",
|
||||
|
@@ -3,7 +3,6 @@ from pathlib import Path
|
||||
import os
|
||||
import base64
|
||||
import pytest
|
||||
from types import SimpleNamespace
|
||||
|
||||
from seedpass import api
|
||||
from test_api import client
|
||||
@@ -25,10 +24,10 @@ def test_create_and_modify_totp_entry(client):
|
||||
def modify(idx, **kwargs):
|
||||
calls["modify"] = (idx, kwargs)
|
||||
|
||||
api._pm.entry_manager.add_totp = add_totp
|
||||
api._pm.entry_manager.modify_entry = modify
|
||||
api._pm.entry_manager.get_next_index = lambda: 5
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.entry_manager.add_totp = add_totp
|
||||
api.app.state.pm.entry_manager.modify_entry = modify
|
||||
api.app.state.pm.entry_manager.get_next_index = lambda: 5
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
@@ -77,9 +76,9 @@ def test_create_and_modify_ssh_entry(client):
|
||||
def modify(idx, **kwargs):
|
||||
calls["modify"] = (idx, kwargs)
|
||||
|
||||
api._pm.entry_manager.add_ssh_key = add_ssh
|
||||
api._pm.entry_manager.modify_entry = modify
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.entry_manager.add_ssh_key = add_ssh
|
||||
api.app.state.pm.entry_manager.modify_entry = modify
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
@@ -107,7 +106,7 @@ def test_update_entry_error(client):
|
||||
def modify(*a, **k):
|
||||
raise ValueError("nope")
|
||||
|
||||
api._pm.entry_manager.modify_entry = modify
|
||||
api.app.state.pm.entry_manager.modify_entry = modify
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.put("/api/v1/entry/1", json={"username": "x"}, headers=headers)
|
||||
assert res.status_code == 400
|
||||
@@ -121,7 +120,7 @@ def test_update_config_secret_mode(client):
|
||||
def set_secret(val):
|
||||
called["val"] = val
|
||||
|
||||
api._pm.config_manager.set_secret_mode_enabled = set_secret
|
||||
api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.put(
|
||||
"/api/v1/config/secret_mode_enabled",
|
||||
@@ -135,8 +134,8 @@ def test_update_config_secret_mode(client):
|
||||
|
||||
def test_totp_export_endpoint(client):
|
||||
cl, token = client
|
||||
api._pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]}
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.entry_manager.export_totp_entries = lambda seed: {"entries": ["x"]}
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
|
||||
res = cl.get("/api/v1/totp/export", headers=headers)
|
||||
assert res.status_code == 200
|
||||
@@ -145,10 +144,12 @@ def test_totp_export_endpoint(client):
|
||||
|
||||
def test_totp_codes_endpoint(client):
|
||||
cl, token = client
|
||||
api._pm.entry_manager.list_entries = lambda **kw: [(0, "Email", None, None, False)]
|
||||
api._pm.entry_manager.get_totp_code = lambda i, s: "123456"
|
||||
api._pm.entry_manager.get_totp_time_remaining = lambda i: 30
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.entry_manager.list_entries = lambda **kw: [
|
||||
(0, "Email", None, None, False)
|
||||
]
|
||||
api.app.state.pm.entry_manager.get_totp_code = lambda i, s: "123456"
|
||||
api.app.state.pm.entry_manager.get_totp_time_remaining = lambda i: 30
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
headers = {"Authorization": f"Bearer {token}", "X-SeedPass-Password": "pw"}
|
||||
res = cl.get("/api/v1/totp", headers=headers)
|
||||
assert res.status_code == 200
|
||||
@@ -169,11 +170,11 @@ def test_fingerprint_endpoints(client):
|
||||
cl, token = client
|
||||
calls = {}
|
||||
|
||||
api._pm.add_new_fingerprint = lambda: calls.setdefault("add", True)
|
||||
api._pm.fingerprint_manager.remove_fingerprint = lambda fp: calls.setdefault(
|
||||
"remove", fp
|
||||
api.app.state.pm.add_new_fingerprint = lambda: calls.setdefault("add", True)
|
||||
api.app.state.pm.fingerprint_manager.remove_fingerprint = (
|
||||
lambda fp: calls.setdefault("remove", fp)
|
||||
)
|
||||
api._pm.select_fingerprint = lambda fp: calls.setdefault("select", fp)
|
||||
api.app.state.pm.select_fingerprint = lambda fp: calls.setdefault("select", fp)
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
@@ -201,8 +202,10 @@ def test_checksum_endpoints(client):
|
||||
cl, token = client
|
||||
calls = {}
|
||||
|
||||
api._pm.handle_verify_checksum = lambda: calls.setdefault("verify", True)
|
||||
api._pm.handle_update_script_checksum = lambda: calls.setdefault("update", True)
|
||||
api.app.state.pm.handle_verify_checksum = lambda: calls.setdefault("verify", True)
|
||||
api.app.state.pm.handle_update_script_checksum = lambda: calls.setdefault(
|
||||
"update", True
|
||||
)
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
@@ -224,9 +227,11 @@ def test_vault_import_via_path(client, tmp_path):
|
||||
def import_db(path):
|
||||
called["path"] = path
|
||||
|
||||
api._pm.handle_import_database = import_db
|
||||
api._pm.sync_vault = lambda: called.setdefault("sync", True)
|
||||
api._pm.encryption_manager = SimpleNamespace(resolve_relative_path=lambda p: p)
|
||||
api.app.state.pm.handle_import_database = import_db
|
||||
api.app.state.pm.sync_vault = lambda: called.setdefault("sync", True)
|
||||
api.app.state.pm.encryption_manager = SimpleNamespace(
|
||||
resolve_relative_path=lambda p: p
|
||||
)
|
||||
file_path = tmp_path / "b.json.enc"
|
||||
file_path.write_text("{}")
|
||||
|
||||
@@ -249,8 +254,8 @@ def test_vault_import_via_upload(client, tmp_path):
|
||||
def import_db(path):
|
||||
called["path"] = path
|
||||
|
||||
api._pm.handle_import_database = import_db
|
||||
api._pm.sync_vault = lambda: called.setdefault("sync", True)
|
||||
api.app.state.pm.handle_import_database = import_db
|
||||
api.app.state.pm.sync_vault = lambda: called.setdefault("sync", True)
|
||||
file_path = tmp_path / "c.json"
|
||||
file_path.write_text("{}")
|
||||
|
||||
@@ -269,9 +274,11 @@ def test_vault_import_via_upload(client, tmp_path):
|
||||
|
||||
def test_vault_import_invalid_extension(client):
|
||||
cl, token = client
|
||||
api._pm.handle_import_database = lambda path: None
|
||||
api._pm.sync_vault = lambda: None
|
||||
api._pm.encryption_manager = SimpleNamespace(resolve_relative_path=lambda p: p)
|
||||
api.app.state.pm.handle_import_database = lambda path: None
|
||||
api.app.state.pm.sync_vault = lambda: None
|
||||
api.app.state.pm.encryption_manager = SimpleNamespace(
|
||||
resolve_relative_path=lambda p: p
|
||||
)
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
@@ -285,9 +292,9 @@ def test_vault_import_invalid_extension(client):
|
||||
def test_vault_import_path_traversal_blocked(client, tmp_path):
|
||||
cl, token = client
|
||||
key = base64.urlsafe_b64encode(os.urandom(32))
|
||||
api._pm.encryption_manager = EncryptionManager(key, tmp_path)
|
||||
api._pm.handle_import_database = lambda path: None
|
||||
api._pm.sync_vault = lambda: None
|
||||
api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path)
|
||||
api.app.state.pm.handle_import_database = lambda path: None
|
||||
api.app.state.pm.sync_vault = lambda: None
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
@@ -304,20 +311,22 @@ def test_vault_lock_endpoint(client):
|
||||
|
||||
def lock():
|
||||
called["locked"] = True
|
||||
api._pm.locked = True
|
||||
api.app.state.pm.locked = True
|
||||
|
||||
api._pm.lock_vault = lock
|
||||
api._pm.locked = False
|
||||
api.app.state.pm.lock_vault = lock
|
||||
api.app.state.pm.locked = False
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post("/api/v1/vault/lock", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert res.json() == {"status": "locked"}
|
||||
assert called.get("locked") is True
|
||||
assert api._pm.locked is True
|
||||
api._pm.unlock_vault = lambda pw: setattr(api._pm, "locked", False)
|
||||
api._pm.unlock_vault("pw")
|
||||
assert api._pm.locked is False
|
||||
assert api.app.state.pm.locked is True
|
||||
api.app.state.pm.unlock_vault = lambda pw: setattr(
|
||||
api.app.state.pm, "locked", False
|
||||
)
|
||||
api.app.state.pm.unlock_vault("pw")
|
||||
assert api.app.state.pm.locked is False
|
||||
|
||||
|
||||
def test_secret_mode_endpoint(client):
|
||||
@@ -330,8 +339,8 @@ def test_secret_mode_endpoint(client):
|
||||
def set_delay(val):
|
||||
called.setdefault("delay", val)
|
||||
|
||||
api._pm.config_manager.set_secret_mode_enabled = set_secret
|
||||
api._pm.config_manager.set_clipboard_clear_delay = set_delay
|
||||
api.app.state.pm.config_manager.set_secret_mode_enabled = set_secret
|
||||
api.app.state.pm.config_manager.set_clipboard_clear_delay = set_delay
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
@@ -350,7 +359,7 @@ def test_vault_export_endpoint(client, tmp_path):
|
||||
out = tmp_path / "out.json"
|
||||
out.write_text("data")
|
||||
|
||||
api._pm.handle_export_database = lambda: out
|
||||
api.app.state.pm.handle_export_database = lambda: out
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
@@ -366,9 +375,9 @@ def test_vault_export_endpoint(client, tmp_path):
|
||||
|
||||
def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
cl, token = client
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
called = {}
|
||||
api._pm.encryption_manager = SimpleNamespace(
|
||||
api.app.state.pm.encryption_manager = SimpleNamespace(
|
||||
encrypt_and_save_file=lambda data, path: called.setdefault("path", path),
|
||||
resolve_relative_path=lambda p: p,
|
||||
)
|
||||
@@ -396,9 +405,9 @@ def test_backup_parent_seed_endpoint(client, tmp_path):
|
||||
|
||||
def test_backup_parent_seed_path_traversal_blocked(client, tmp_path):
|
||||
cl, token = client
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
key = base64.urlsafe_b64encode(os.urandom(32))
|
||||
api._pm.encryption_manager = EncryptionManager(key, tmp_path)
|
||||
api.app.state.pm.encryption_manager = EncryptionManager(key, tmp_path)
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"X-SeedPass-Password": "pw",
|
||||
@@ -424,8 +433,8 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
def set_relays(new, require_pin=False):
|
||||
called["set"] = new
|
||||
|
||||
api._pm.config_manager.load_config = load_config
|
||||
api._pm.config_manager.set_relays = set_relays
|
||||
api.app.state.pm.config_manager.load_config = load_config
|
||||
api.app.state.pm.config_manager.set_relays = set_relays
|
||||
monkeypatch.setattr(
|
||||
NostrClient,
|
||||
"initialize_client_pool",
|
||||
@@ -434,8 +443,8 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
monkeypatch.setattr(
|
||||
nostr_client, "close_client_pool", lambda: called.setdefault("close", True)
|
||||
)
|
||||
api._pm.nostr_client = nostr_client
|
||||
api._pm.nostr_client.relays = relays.copy()
|
||||
api.app.state.pm.nostr_client = nostr_client
|
||||
api.app.state.pm.nostr_client.relays = relays.copy()
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
@@ -447,7 +456,7 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
assert res.status_code == 200
|
||||
assert called["set"] == ["wss://a", "wss://b", "wss://c"]
|
||||
|
||||
api._pm.config_manager.load_config = lambda require_pin=False: {
|
||||
api.app.state.pm.config_manager.load_config = lambda require_pin=False: {
|
||||
"relays": ["wss://a", "wss://b", "wss://c"]
|
||||
}
|
||||
res = cl.delete("/api/v1/relays/2", headers=headers)
|
||||
@@ -457,7 +466,7 @@ def test_relay_management_endpoints(client, dummy_nostr_client, monkeypatch):
|
||||
res = cl.post("/api/v1/relays/reset", headers=headers)
|
||||
assert res.status_code == 200
|
||||
assert called.get("init") is True
|
||||
assert api._pm.nostr_client.relays == list(DEFAULT_RELAYS)
|
||||
assert api.app.state.pm.nostr_client.relays == list(DEFAULT_RELAYS)
|
||||
|
||||
|
||||
def test_generate_password_no_special_chars(client):
|
||||
@@ -471,8 +480,10 @@ def test_generate_password_no_special_chars(client):
|
||||
def derive_entropy(self, index: int, bytes_len: int, app_no: int = 32) -> bytes:
|
||||
return bytes(range(bytes_len))
|
||||
|
||||
api._pm.password_generator = PasswordGenerator(DummyEnc(), "seed", DummyBIP85())
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.password_generator = PasswordGenerator(
|
||||
DummyEnc(), "seed", DummyBIP85()
|
||||
)
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
res = cl.post(
|
||||
@@ -496,8 +507,10 @@ def test_generate_password_allowed_chars(client):
|
||||
def derive_entropy(self, index: int, bytes_len: int, app_no: int = 32) -> bytes:
|
||||
return bytes((index + i) % 256 for i in range(bytes_len))
|
||||
|
||||
api._pm.password_generator = PasswordGenerator(DummyEnc(), "seed", DummyBIP85())
|
||||
api._pm.parent_seed = "seed"
|
||||
api.app.state.pm.password_generator = PasswordGenerator(
|
||||
DummyEnc(), "seed", DummyBIP85()
|
||||
)
|
||||
api.app.state.pm.parent_seed = "seed"
|
||||
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
allowed = "@$"
|
||||
|
@@ -6,40 +6,42 @@ import seedpass.api as api
|
||||
|
||||
def test_notifications_endpoint(client):
|
||||
cl, token = client
|
||||
api._pm.notifications = queue.Queue()
|
||||
api._pm.notifications.put(SimpleNamespace(message="m1", level="INFO"))
|
||||
api._pm.notifications.put(SimpleNamespace(message="m2", level="WARNING"))
|
||||
api.app.state.pm.notifications = queue.Queue()
|
||||
api.app.state.pm.notifications.put(SimpleNamespace(message="m1", level="INFO"))
|
||||
api.app.state.pm.notifications.put(SimpleNamespace(message="m2", level="WARNING"))
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == [
|
||||
{"level": "INFO", "message": "m1"},
|
||||
{"level": "WARNING", "message": "m2"},
|
||||
]
|
||||
assert api._pm.notifications.empty()
|
||||
assert api.app.state.pm.notifications.empty()
|
||||
|
||||
|
||||
def test_notifications_endpoint_clears_queue(client):
|
||||
cl, token = client
|
||||
api._pm.notifications = queue.Queue()
|
||||
api._pm.notifications.put(SimpleNamespace(message="hi", level="INFO"))
|
||||
api.app.state.pm.notifications = queue.Queue()
|
||||
api.app.state.pm.notifications.put(SimpleNamespace(message="hi", level="INFO"))
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == [{"level": "INFO", "message": "hi"}]
|
||||
assert api._pm.notifications.empty()
|
||||
assert api.app.state.pm.notifications.empty()
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
assert res.json() == []
|
||||
|
||||
|
||||
def test_notifications_endpoint_does_not_clear_current(client):
|
||||
cl, token = client
|
||||
api._pm.notifications = queue.Queue()
|
||||
api.app.state.pm.notifications = queue.Queue()
|
||||
msg = SimpleNamespace(message="keep", level="INFO")
|
||||
api._pm.notifications.put(msg)
|
||||
api._pm._current_notification = msg
|
||||
api._pm.get_current_notification = lambda: api._pm._current_notification
|
||||
api.app.state.pm.notifications.put(msg)
|
||||
api.app.state.pm._current_notification = msg
|
||||
api.app.state.pm.get_current_notification = (
|
||||
lambda: api.app.state.pm._current_notification
|
||||
)
|
||||
|
||||
res = cl.get("/api/v1/notifications", headers={"Authorization": f"Bearer {token}"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == [{"level": "INFO", "message": "keep"}]
|
||||
assert api._pm.notifications.empty()
|
||||
assert api._pm.get_current_notification() is msg
|
||||
assert api.app.state.pm.notifications.empty()
|
||||
assert api.app.state.pm.get_current_notification() is msg
|
||||
|
@@ -7,7 +7,7 @@ def test_profile_stats_endpoint(client):
|
||||
# monkeypatch set _pm.get_profile_stats after client fixture started
|
||||
import seedpass.api as api
|
||||
|
||||
api._pm.get_profile_stats = lambda: stats
|
||||
api.app.state.pm.get_profile_stats = lambda: stats
|
||||
res = cl.get("/api/v1/stats", headers={"Authorization": f"Bearer {token}"})
|
||||
assert res.status_code == 200
|
||||
assert res.json() == stats
|
||||
|
Reference in New Issue
Block a user