refactor: move api state to app

This commit is contained in:
thePR0M3TH3AN
2025-08-05 19:14:11 -04:00
parent fa4826fe2d
commit 20ee8a891b
6 changed files with 302 additions and 248 deletions

View File

@@ -36,35 +36,39 @@ _RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds"
limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR])
app = FastAPI()
_pm: Optional[PasswordManager] = None
_token: str = ""
_jwt_secret: str = ""
def _get_pm(request: Request) -> PasswordManager:
pm = getattr(request.app.state, "pm", None)
assert pm is not None
return pm
def _check_token(auth: str | None) -> None:
def _check_token(request: Request, auth: str | None) -> None:
if auth is None or not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = auth.split(" ", 1)[1]
jwt_secret = getattr(request.app.state, "jwt_secret", "")
token_hash = getattr(request.app.state, "token_hash", "")
try:
jwt.decode(token, _jwt_secret, algorithms=["HS256"])
jwt.decode(token, jwt_secret, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Unauthorized")
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), _token):
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), token_hash):
raise HTTPException(status_code=401, detail="Unauthorized")
def _reload_relays(relays: list[str]) -> None:
def _reload_relays(request: Request, relays: list[str]) -> None:
"""Reload the Nostr client with a new relay list."""
assert _pm is not None
pm = _get_pm(request)
try:
_pm.nostr_client.close_client_pool()
pm.nostr_client.close_client_pool()
except Exception:
pass
try:
_pm.nostr_client.relays = relays
_pm.nostr_client.initialize_client_pool()
pm.nostr_client.relays = relays
pm.nostr_client.initialize_client_pool()
except Exception:
pass
@@ -77,15 +81,15 @@ def start_server(fingerprint: str | None = None) -> str:
fingerprint:
Optional seed profile fingerprint to select before starting the server.
"""
global _pm, _token, _jwt_secret
if fingerprint is None:
_pm = PasswordManager()
pm = PasswordManager()
else:
_pm = PasswordManager(fingerprint=fingerprint)
_jwt_secret = secrets.token_urlsafe(32)
pm = PasswordManager(fingerprint=fingerprint)
app.state.pm = pm
app.state.jwt_secret = secrets.token_urlsafe(32)
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
raw_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
_token = hashlib.sha256(raw_token.encode()).hexdigest()
raw_token = jwt.encode(payload, app.state.jwt_secret, algorithm="HS256")
app.state.token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
if not getattr(app.state, "limiter", None):
app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
@@ -105,30 +109,32 @@ def start_server(fingerprint: str | None = None) -> str:
return raw_token
def _require_password(password: str | None) -> None:
assert _pm is not None
if password is None or not _pm.verify_password(password):
def _require_password(request: Request, password: str | None) -> None:
pm = _get_pm(request)
if password is None or not pm.verify_password(password):
raise HTTPException(status_code=401, detail="Invalid password")
def _validate_encryption_path(path: Path) -> Path:
def _validate_encryption_path(request: Request, path: Path) -> Path:
"""Validate and normalize ``path`` within the active fingerprint directory.
Returns the resolved absolute path if validation succeeds.
"""
assert _pm is not None
pm = _get_pm(request)
try:
return _pm.encryption_manager.resolve_relative_path(path)
return pm.encryption_manager.resolve_relative_path(path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
@app.get("/api/v1/entry")
def search_entry(query: str, authorization: str | None = Header(None)) -> List[Any]:
_check_token(authorization)
assert _pm is not None
results = _pm.entry_manager.search_entries(query)
def search_entry(
request: Request, query: str, authorization: str | None = Header(None)
) -> List[Any]:
_check_token(request, authorization)
pm = _get_pm(request)
results = pm.entry_manager.search_entries(query)
return [
{
"id": idx,
@@ -144,14 +150,15 @@ def search_entry(query: str, authorization: str | None = Header(None)) -> List[A
@app.get("/api/v1/entry/{entry_id}")
def get_entry(
request: Request,
entry_id: int,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> Any:
_check_token(authorization)
_require_password(password)
assert _pm is not None
entry = _pm.entry_manager.retrieve_entry(entry_id)
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
entry = pm.entry_manager.retrieve_entry(entry_id)
if entry is None:
raise HTTPException(status_code=404, detail="Not found")
return entry
@@ -159,6 +166,7 @@ def get_entry(
@app.post("/api/v1/entry")
def create_entry(
request: Request,
entry: dict,
authorization: str | None = Header(None),
) -> dict[str, Any]:
@@ -168,8 +176,8 @@ def create_entry(
on, the corresponding entry type is created. When omitted or set to
``password`` the behaviour matches the legacy password-entry API.
"""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
etype = (entry.get("type") or entry.get("kind") or "password").lower()
@@ -186,7 +194,7 @@ def create_entry(
]
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
index = _pm.entry_manager.add_entry(
index = pm.entry_manager.add_entry(
entry.get("label"),
int(entry.get("length", 12)),
entry.get("username"),
@@ -196,11 +204,11 @@ def create_entry(
return {"id": index}
if etype == "totp":
index = _pm.entry_manager.get_next_index()
index = pm.entry_manager.get_next_index()
uri = _pm.entry_manager.add_totp(
uri = pm.entry_manager.add_totp(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
secret=entry.get("secret"),
index=entry.get("index"),
period=int(entry.get("period", 30)),
@@ -211,9 +219,9 @@ def create_entry(
return {"id": index, "uri": uri}
if etype == "ssh":
index = _pm.entry_manager.add_ssh_key(
index = pm.entry_manager.add_ssh_key(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
notes=entry.get("notes", ""),
archived=entry.get("archived", False),
@@ -221,9 +229,9 @@ def create_entry(
return {"id": index}
if etype == "pgp":
index = _pm.entry_manager.add_pgp_key(
index = pm.entry_manager.add_pgp_key(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
key_type=entry.get("key_type", "ed25519"),
user_id=entry.get("user_id", ""),
@@ -233,9 +241,9 @@ def create_entry(
return {"id": index}
if etype == "nostr":
index = _pm.entry_manager.add_nostr_key(
index = pm.entry_manager.add_nostr_key(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
notes=entry.get("notes", ""),
archived=entry.get("archived", False),
@@ -243,7 +251,7 @@ def create_entry(
return {"id": index}
if etype == "key_value":
index = _pm.entry_manager.add_key_value(
index = pm.entry_manager.add_key_value(
entry.get("label"),
entry.get("key"),
entry.get("value"),
@@ -253,13 +261,13 @@ def create_entry(
if etype in {"seed", "managed_account"}:
func = (
_pm.entry_manager.add_seed
pm.entry_manager.add_seed
if etype == "seed"
else _pm.entry_manager.add_managed_account
else pm.entry_manager.add_managed_account
)
index = func(
entry.get("label"),
_pm.parent_seed,
pm.parent_seed,
index=entry.get("index"),
notes=entry.get("notes", ""),
)
@@ -270,6 +278,7 @@ def create_entry(
@app.put("/api/v1/entry/{entry_id}")
def update_entry(
request: Request,
entry_id: int,
entry: dict,
authorization: str | None = Header(None),
@@ -279,10 +288,10 @@ def update_entry(
Additional fields like ``period``, ``digits`` and ``value`` are forwarded for
specialized entry types (e.g. TOTP or key/value entries).
"""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
try:
_pm.entry_manager.modify_entry(
pm.entry_manager.modify_entry(
entry_id,
username=entry.get("username"),
url=entry.get("url"),
@@ -300,31 +309,33 @@ def update_entry(
@app.post("/api/v1/entry/{entry_id}/archive")
def archive_entry(
entry_id: int, authorization: str | None = Header(None)
request: Request, entry_id: int, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Archive an entry."""
_check_token(authorization)
assert _pm is not None
_pm.entry_manager.archive_entry(entry_id)
_check_token(request, authorization)
pm = _get_pm(request)
pm.entry_manager.archive_entry(entry_id)
return {"status": "archived"}
@app.post("/api/v1/entry/{entry_id}/unarchive")
def unarchive_entry(
entry_id: int, authorization: str | None = Header(None)
request: Request, entry_id: int, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Restore an archived entry."""
_check_token(authorization)
assert _pm is not None
_pm.entry_manager.restore_entry(entry_id)
_check_token(request, authorization)
pm = _get_pm(request)
pm.entry_manager.restore_entry(entry_id)
return {"status": "active"}
@app.get("/api/v1/config/{key}")
def get_config(key: str, authorization: str | None = Header(None)) -> Any:
_check_token(authorization)
assert _pm is not None
value = _pm.config_manager.load_config(require_pin=False).get(key)
def get_config(
request: Request, key: str, authorization: str | None = Header(None)
) -> Any:
_check_token(request, authorization)
pm = _get_pm(request)
value = pm.config_manager.load_config(require_pin=False).get(key)
if value is None:
raise HTTPException(status_code=404, detail="Not found")
@@ -333,12 +344,15 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any:
@app.put("/api/v1/config/{key}")
def update_config(
key: str, data: dict, authorization: str | None = Header(None)
request: Request,
key: str,
data: dict,
authorization: str | None = Header(None),
) -> dict[str, str]:
"""Update a configuration setting."""
_check_token(authorization)
assert _pm is not None
cfg = _pm.config_manager
_check_token(request, authorization)
pm = _get_pm(request)
cfg = pm.config_manager
mapping = {
"relays": lambda v: cfg.set_relays(v, require_pin=False),
"pin": cfg.set_pin,
@@ -364,96 +378,102 @@ def update_config(
@app.post("/api/v1/secret-mode")
def set_secret_mode(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Enable/disable secret mode and set the clipboard delay."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
enabled = data.get("enabled")
delay = data.get("delay")
if enabled is None or delay is None:
raise HTTPException(status_code=400, detail="Missing fields")
cfg = _pm.config_manager
cfg = pm.config_manager
cfg.set_secret_mode_enabled(bool(enabled))
cfg.set_clipboard_clear_delay(int(delay))
_pm.secret_mode_enabled = bool(enabled)
_pm.clipboard_clear_delay = int(delay)
pm.secret_mode_enabled = bool(enabled)
pm.clipboard_clear_delay = int(delay)
return {"status": "ok"}
@app.get("/api/v1/fingerprint")
def list_fingerprints(authorization: str | None = Header(None)) -> List[str]:
_check_token(authorization)
assert _pm is not None
return _pm.fingerprint_manager.list_fingerprints()
def list_fingerprints(
request: Request, authorization: str | None = Header(None)
) -> List[str]:
_check_token(request, authorization)
pm = _get_pm(request)
return pm.fingerprint_manager.list_fingerprints()
@app.post("/api/v1/fingerprint")
def add_fingerprint(authorization: str | None = Header(None)) -> dict[str, str]:
def add_fingerprint(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Create a new seed profile."""
_check_token(authorization)
assert _pm is not None
_pm.add_new_fingerprint()
_check_token(request, authorization)
pm = _get_pm(request)
pm.add_new_fingerprint()
return {"status": "ok"}
@app.delete("/api/v1/fingerprint/{fingerprint}")
def remove_fingerprint(
fingerprint: str, authorization: str | None = Header(None)
request: Request, fingerprint: str, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Remove a seed profile."""
_check_token(authorization)
assert _pm is not None
_pm.fingerprint_manager.remove_fingerprint(fingerprint)
_check_token(request, authorization)
pm = _get_pm(request)
pm.fingerprint_manager.remove_fingerprint(fingerprint)
return {"status": "deleted"}
@app.post("/api/v1/fingerprint/select")
def select_fingerprint(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Switch the active seed profile."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
fp = data.get("fingerprint")
if not fp:
raise HTTPException(status_code=400, detail="Missing fingerprint")
_pm.select_fingerprint(fp)
pm.select_fingerprint(fp)
return {"status": "ok"}
@app.get("/api/v1/totp/export")
def export_totp(
request: Request,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict:
"""Return all stored TOTP entries in JSON format."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
return _pm.entry_manager.export_totp_entries(_pm.parent_seed)
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
return pm.entry_manager.export_totp_entries(pm.parent_seed)
@app.get("/api/v1/totp")
def get_totp_codes(
request: Request,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict:
"""Return active TOTP codes with remaining seconds."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
entries = _pm.entry_manager.list_entries(
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
entries = pm.entry_manager.list_entries(
filter_kind=EntryType.TOTP.value, include_archived=False
)
codes = []
for idx, label, _u, _url, _arch in entries:
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
code = pm.entry_manager.get_totp_code(idx, pm.parent_seed)
rem = _pm.entry_manager.get_totp_time_remaining(idx)
rem = pm.entry_manager.get_totp_time_remaining(idx)
codes.append(
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
@@ -462,23 +482,26 @@ def get_totp_codes(
@app.get("/api/v1/stats")
def get_profile_stats(authorization: str | None = Header(None)) -> dict:
def get_profile_stats(
request: Request, authorization: str | None = Header(None)
) -> dict:
"""Return statistics about the active seed profile."""
_check_token(authorization)
assert _pm is not None
return _pm.get_profile_stats()
_check_token(request, authorization)
pm = _get_pm(request)
return pm.get_profile_stats()
@app.get("/api/v1/notifications")
def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
def get_notifications(
request: Request, authorization: str | None = Header(None)
) -> List[dict]:
"""Return and clear queued notifications."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
notes = []
while True:
try:
note = _pm.notifications.get_nowait()
note = pm.notifications.get_nowait()
except queue.Empty:
break
notes.append({"level": note.level, "message": note.message})
@@ -486,47 +509,51 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
@app.get("/api/v1/nostr/pubkey")
def get_nostr_pubkey(authorization: str | None = Header(None)) -> Any:
_check_token(authorization)
assert _pm is not None
return {"npub": _pm.nostr_client.key_manager.get_npub()}
def get_nostr_pubkey(request: Request, authorization: str | None = Header(None)) -> Any:
_check_token(request, authorization)
pm = _get_pm(request)
return {"npub": pm.nostr_client.key_manager.get_npub()}
@app.get("/api/v1/relays")
def list_relays(authorization: str | None = Header(None)) -> dict:
def list_relays(request: Request, authorization: str | None = Header(None)) -> dict:
"""Return the configured Nostr relays."""
_check_token(authorization)
assert _pm is not None
cfg = _pm.config_manager.load_config(require_pin=False)
_check_token(request, authorization)
pm = _get_pm(request)
cfg = pm.config_manager.load_config(require_pin=False)
return {"relays": cfg.get("relays", [])}
@app.post("/api/v1/relays")
def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str, str]:
def add_relay(
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Add a relay URL to the configuration."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
url = data.get("url")
if not url:
raise HTTPException(status_code=400, detail="Missing url")
cfg = _pm.config_manager.load_config(require_pin=False)
cfg = pm.config_manager.load_config(require_pin=False)
relays = cfg.get("relays", [])
if url in relays:
raise HTTPException(status_code=400, detail="Relay already present")
relays.append(url)
_pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(relays)
pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(request, relays)
return {"status": "ok"}
@app.delete("/api/v1/relays/{idx}")
def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str, str]:
def remove_relay(
request: Request, idx: int, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Remove a relay by its index (1-based)."""
_check_token(authorization)
assert _pm is not None
cfg = _pm.config_manager.load_config(require_pin=False)
_check_token(request, authorization)
pm = _get_pm(request)
cfg = pm.config_manager.load_config(require_pin=False)
relays = cfg.get("relays", [])
if not (1 <= idx <= len(relays)):
@@ -534,52 +561,59 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str
if len(relays) == 1:
raise HTTPException(status_code=400, detail="At least one relay required")
relays.pop(idx - 1)
_pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(relays)
pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(request, relays)
return {"status": "ok"}
@app.post("/api/v1/relays/reset")
def reset_relays(authorization: str | None = Header(None)) -> dict[str, str]:
def reset_relays(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Reset relay list to defaults."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
from nostr.client import DEFAULT_RELAYS
relays = list(DEFAULT_RELAYS)
_pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(relays)
pm.config_manager.set_relays(relays, require_pin=False)
_reload_relays(request, relays)
return {"status": "ok"}
@app.post("/api/v1/checksum/verify")
def verify_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
def verify_checksum(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Verify the SeedPass script checksum."""
_check_token(authorization)
assert _pm is not None
_pm.handle_verify_checksum()
_check_token(request, authorization)
pm = _get_pm(request)
pm.handle_verify_checksum()
return {"status": "ok"}
@app.post("/api/v1/checksum/update")
def update_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
def update_checksum(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Regenerate the script checksum file."""
_check_token(authorization)
assert _pm is not None
_pm.handle_update_script_checksum()
_check_token(request, authorization)
pm = _get_pm(request)
pm.handle_update_script_checksum()
return {"status": "ok"}
@app.post("/api/v1/vault/export")
def export_vault(
request: Request,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
):
"""Export the vault and return the encrypted file."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
path = _pm.handle_export_database()
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
path = pm.handle_export_database()
if path is None:
raise HTTPException(status_code=500, detail="Export failed")
data = Path(path).read_bytes()
@@ -591,8 +625,8 @@ async def import_vault(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Import a vault backup from a file upload or a server path."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
ctype = request.headers.get("content-type", "")
@@ -607,7 +641,7 @@ async def import_vault(
tmp.write(data)
tmp_path = Path(tmp.name)
try:
_pm.handle_import_database(tmp_path)
pm.handle_import_database(tmp_path)
finally:
os.unlink(tmp_path)
else:
@@ -617,28 +651,29 @@ async def import_vault(
if not path_str:
raise HTTPException(status_code=400, detail="Missing file or path")
path = _validate_encryption_path(Path(path_str))
path = _validate_encryption_path(request, Path(path_str))
if not str(path).endswith(".json.enc"):
raise HTTPException(
status_code=400,
detail="Selected file must be a '.json.enc' backup",
)
_pm.handle_import_database(path)
_pm.sync_vault()
pm.handle_import_database(path)
pm.sync_vault()
return {"status": "ok"}
@app.post("/api/v1/vault/backup-parent-seed")
def backup_parent_seed(
request: Request,
data: dict,
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict[str, str]:
"""Create an encrypted backup of the parent seed after confirmation."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
_check_token(request, authorization)
_require_password(request, password)
pm = _get_pm(request)
if not data.get("confirm"):
@@ -649,30 +684,30 @@ def backup_parent_seed(
if not path_str:
raise HTTPException(status_code=400, detail="Missing path")
path = Path(path_str)
_validate_encryption_path(path)
_pm.encryption_manager.encrypt_and_save_file(_pm.parent_seed.encode("utf-8"), path)
_validate_encryption_path(request, path)
pm.encryption_manager.encrypt_and_save_file(pm.parent_seed.encode("utf-8"), path)
return {"status": "saved", "path": str(path)}
@app.post("/api/v1/change-password")
def change_password(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Change the master password for the active profile."""
_check_token(authorization)
assert _pm is not None
_pm.change_password(data.get("old", ""), data.get("new", ""))
_check_token(request, authorization)
pm = _get_pm(request)
pm.change_password(data.get("old", ""), data.get("new", ""))
return {"status": "ok"}
@app.post("/api/v1/password")
def generate_password(
data: dict, authorization: str | None = Header(None)
request: Request, data: dict, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Generate a password using optional policy overrides."""
_check_token(authorization)
assert _pm is not None
_check_token(request, authorization)
pm = _get_pm(request)
length = int(data.get("length", 12))
policy_keys = [
@@ -687,23 +722,27 @@ def generate_password(
]
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
util = UtilityService(_pm)
util = UtilityService(pm)
password = util.generate_password(length, **kwargs)
return {"password": password}
@app.post("/api/v1/vault/lock")
def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]:
def lock_vault(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
"""Lock the vault and clear sensitive data from memory."""
_check_token(authorization)
assert _pm is not None
_pm.lock_vault()
_check_token(request, authorization)
pm = _get_pm(request)
pm.lock_vault()
return {"status": "locked"}
@app.post("/api/v1/shutdown")
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
_check_token(authorization)
async def shutdown_server(
request: Request, authorization: str | None = Header(None)
) -> dict[str, str]:
_check_token(request, authorization)
asyncio.get_event_loop().call_soon(sys.exit, 0)
return {"status": "shutting down"}