Switch API token to bcrypt

This commit is contained in:
thePR0M3TH3AN
2025-08-20 19:29:08 -04:00
parent da37ec2e61
commit d99af30d9f
4 changed files with 21 additions and 26 deletions

View File

@@ -127,7 +127,7 @@ Run or stop the local HTTP API.
| Action | Command | Examples | | Action | Command | Examples |
| :--- | :--- | :--- | | :--- | :--- | :--- |
| Start the API | `api start` | `seedpass api start --host 0.0.0.0 --port 8000` | | Start the API | `api start` | `seedpass api start --host 0.0.0.0 --port 8000` |
| Stop the API | `api stop` | `seedpass api stop` | | Stop the API | `api stop --token TOKEN` | `seedpass api stop --token <token>` |
--- ---
@@ -214,7 +214,7 @@ Set the `SEEDPASS_CORS_ORIGINS` environment variable to a commaseparated list
SEEDPASS_CORS_ORIGINS=http://localhost:3000 seedpass api start SEEDPASS_CORS_ORIGINS=http://localhost:3000 seedpass api start
``` ```
Shut down the server with `seedpass api stop`. Shut down the server with `seedpass api stop --token <token>`.
--- ---

View File

@@ -9,8 +9,6 @@ import secrets
import queue import queue
from typing import Any, List, Optional from typing import Any, List, Optional
from datetime import datetime, timedelta, timezone
import jwt
import logging import logging
from fastapi import FastAPI, Header, HTTPException, Request, Response from fastapi import FastAPI, Header, HTTPException, Request, Response
@@ -18,8 +16,8 @@ from fastapi.concurrency import run_in_threadpool
import asyncio import asyncio
import sys import sys
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
import hashlib
import hmac import bcrypt
from slowapi import Limiter, _rate_limit_exceeded_handler from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded from slowapi.errors import RateLimitExceeded
@@ -50,16 +48,9 @@ def _get_pm(request: Request) -> PasswordManager:
def _check_token(request: Request, auth: str | None) -> None: def _check_token(request: Request, auth: str | None) -> None:
if auth is None or not auth.startswith("Bearer "): if auth is None or not auth.startswith("Bearer "):
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
token = auth.split(" ", 1)[1] token = auth.split(" ", 1)[1].encode()
jwt_secret = getattr(request.app.state, "jwt_secret", "") token_hash = getattr(request.app.state, "token_hash", b"")
token_hash = getattr(request.app.state, "token_hash", "") if not token_hash or not bcrypt.checkpw(token, token_hash):
try:
jwt.decode(token, jwt_secret, algorithms=["HS256"])
except jwt.ExpiredSignatureError:
raise HTTPException(status_code=401, detail="Token expired")
except jwt.InvalidTokenError:
raise HTTPException(status_code=401, detail="Unauthorized")
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), token_hash):
raise HTTPException(status_code=401, detail="Unauthorized") raise HTTPException(status_code=401, detail="Unauthorized")
@@ -78,7 +69,7 @@ def _reload_relays(request: Request, relays: list[str]) -> None:
def start_server(fingerprint: str | None = None) -> str: def start_server(fingerprint: str | None = None) -> str:
"""Initialize global state and return a short-lived JWT token. """Initialize global state and return a random API token.
Parameters Parameters
---------- ----------
@@ -90,10 +81,8 @@ def start_server(fingerprint: str | None = None) -> str:
else: else:
pm = PasswordManager(fingerprint=fingerprint) pm = PasswordManager(fingerprint=fingerprint)
app.state.pm = pm app.state.pm = pm
app.state.jwt_secret = secrets.token_urlsafe(32) raw_token = secrets.token_urlsafe(32)
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)} app.state.token_hash = bcrypt.hashpw(raw_token.encode(), bcrypt.gensalt())
raw_token = jwt.encode(payload, app.state.jwt_secret, algorithm="HS256")
app.state.token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
if not getattr(app.state, "limiter", None): if not getattr(app.state, "limiter", None):
app.state.limiter = limiter app.state.limiter = limiter
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler) app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)

View File

@@ -13,19 +13,25 @@ app = typer.Typer(help="Run the API server")
def api_start(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> None: def api_start(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> None:
"""Start the SeedPass API server.""" """Start the SeedPass API server."""
token = api_module.start_server(ctx.obj.get("fingerprint")) token = api_module.start_server(ctx.obj.get("fingerprint"))
typer.echo(f"API token: {token}") typer.echo(
f"API token: {token}\nWARNING: Store this token securely; it cannot be recovered."
)
uvicorn.run(api_module.app, host=host, port=port) uvicorn.run(api_module.app, host=host, port=port)
@app.command("stop") @app.command("stop")
def api_stop(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> None: def api_stop(
token: str = typer.Option(..., help="API token"),
host: str = "127.0.0.1",
port: int = 8000,
) -> None:
"""Stop the SeedPass API server.""" """Stop the SeedPass API server."""
import requests import requests
try: try:
requests.post( requests.post(
f"http://{host}:{port}/api/v1/shutdown", f"http://{host}:{port}/api/v1/shutdown",
headers={"Authorization": f"Bearer {api_module.app.state.token_hash}"}, headers={"Authorization": f"Bearer {token}"},
timeout=2, timeout=2,
) )
except Exception as exc: # pragma: no cover - best effort except Exception as exc: # pragma: no cover - best effort

View File

@@ -4,7 +4,7 @@ import sys
import pytest import pytest
from httpx import ASGITransport, AsyncClient from httpx import ASGITransport, AsyncClient
import hashlib import bcrypt
sys.path.append(str(Path(__file__).resolve().parents[1])) sys.path.append(str(Path(__file__).resolve().parents[1]))
@@ -54,7 +54,7 @@ async def client(monkeypatch):
async def test_token_hashed(client): async def test_token_hashed(client):
_, token = client _, token = client
assert api.app.state.token_hash != token assert api.app.state.token_hash != token
assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest() assert bcrypt.checkpw(token.encode(), api.app.state.token_hash)
@pytest.mark.anyio @pytest.mark.anyio