feat: add short-lived JWT auth and secure endpoints

This commit is contained in:
thePR0M3TH3AN
2025-08-02 21:48:52 -04:00
parent 8c9fe07609
commit 186e39cc91
6 changed files with 104 additions and 16 deletions

View File

@@ -31,6 +31,7 @@ starlette>=0.47.2
httpx>=0.28.1
requests>=2.32
python-multipart
PyJWT
orjson
argon2-cffi
toga-core>=0.5.2

View File

@@ -9,6 +9,9 @@ import secrets
import queue
from typing import Any, List, Optional
from datetime import datetime, timedelta, timezone
import jwt
from fastapi import FastAPI, Header, HTTPException, Request, Response
import asyncio
import sys
@@ -23,10 +26,18 @@ app = FastAPI()
_pm: Optional[PasswordManager] = None
_token: str = ""
_jwt_secret: str = ""
def _check_token(auth: str | None) -> None:
if auth != f"Bearer {_token}":
if auth is None or not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Unauthorized")
token = auth.split(" ", 1)[1]
try:
jwt.decode(token, _jwt_secret, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Unauthorized")
@@ -45,20 +56,21 @@ def _reload_relays(relays: list[str]) -> None:
def start_server(fingerprint: str | None = None) -> str:
"""Initialize global state and return the API token.
"""Initialize global state and return a short-lived JWT token.
Parameters
----------
fingerprint:
Optional seed profile fingerprint to select before starting the server.
"""
global _pm, _token
global _pm, _token, _jwt_secret
if fingerprint is None:
_pm = PasswordManager()
else:
_pm = PasswordManager(fingerprint=fingerprint)
_token = secrets.token_urlsafe(16)
print(f"API token: {_token}")
_jwt_secret = secrets.token_urlsafe(32)
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
origins = [
o.strip()
for o in os.getenv("SEEDPASS_CORS_ORIGINS", "").split(",")
@@ -74,6 +86,12 @@ def start_server(fingerprint: str | None = None) -> str:
return _token
def _require_password(password: str | None) -> None:
assert _pm is not None
if password is None or not _pm.verify_password(password):
raise HTTPException(status_code=401, detail="Invalid password")
@app.get("/api/v1/entry")
def search_entry(query: str, authorization: str | None = Header(None)) -> List[Any]:
_check_token(authorization)
@@ -414,10 +432,13 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
@app.get("/api/v1/parent-seed")
def get_parent_seed(
authorization: str | None = Header(None), file: str | None = None
authorization: str | None = Header(None),
file: str | None = None,
password: str | None = Header(None, alias="X-SeedPass-Password"),
) -> dict:
"""Return the parent seed or save it as an encrypted backup."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
if file:
path = Path(file)
@@ -511,9 +532,13 @@ def update_checksum(authorization: str | None = Header(None)) -> dict[str, str]:
@app.post("/api/v1/vault/export")
def export_vault(authorization: str | None = Header(None)):
def export_vault(
authorization: str | None = Header(None),
password: str | None = Header(None, alias="X-SeedPass-Password"),
):
"""Export the vault and return the encrypted file."""
_check_token(authorization)
_require_password(password)
assert _pm is not None
path = _pm.handle_export_database()
if path is None:

View File

@@ -39,6 +39,7 @@ def client(monkeypatch):
nostr_client=SimpleNamespace(
key_manager=SimpleNamespace(get_npub=lambda: "np")
),
verify_password=lambda pw: True,
)
monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
monkeypatch.setenv("SEEDPASS_CORS_ORIGINS", "http://example.com")

View File

@@ -162,7 +162,10 @@ def test_parent_seed_endpoint(client, tmp_path):
api._pm.encryption_manager = SimpleNamespace(
encrypt_and_save_file=lambda data, path: called.setdefault("path", path)
)
headers = {"Authorization": f"Bearer {token}"}
headers = {
"Authorization": f"Bearer {token}",
"X-SeedPass-Password": "pw",
}
res = cl.get("/api/v1/parent-seed", headers=headers)
assert res.status_code == 200
@@ -174,6 +177,9 @@ def test_parent_seed_endpoint(client, tmp_path):
assert res.json() == {"status": "saved", "path": str(out)}
assert called["path"] == out
res = cl.get("/api/v1/parent-seed", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 401
def test_fingerprint_endpoints(client):
cl, token = client
@@ -330,11 +336,17 @@ def test_vault_export_endpoint(client, tmp_path):
api._pm.handle_export_database = lambda: out
headers = {"Authorization": f"Bearer {token}"}
headers = {
"Authorization": f"Bearer {token}",
"X-SeedPass-Password": "pw",
}
res = cl.post("/api/v1/vault/export", headers=headers)
assert res.status_code == 200
assert res.content == b"data"
res = cl.post("/api/v1/vault/export", headers={"Authorization": f"Bearer {token}"})
assert res.status_code == 401
def test_backup_parent_seed_endpoint(client, tmp_path):
cl, token = client