mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-09 15:58:48 +00:00
Merge pull request #834 from PR0M3TH3AN/codex/implement-token-generation-and-validation
Use bcrypt-hashed API tokens
This commit is contained in:
@@ -127,7 +127,7 @@ Run or stop the local HTTP API.
|
|||||||
| Action | Command | Examples |
|
| Action | Command | Examples |
|
||||||
| :--- | :--- | :--- |
|
| :--- | :--- | :--- |
|
||||||
| Start the API | `api start` | `seedpass api start --host 0.0.0.0 --port 8000` |
|
| Start the API | `api start` | `seedpass api start --host 0.0.0.0 --port 8000` |
|
||||||
| Stop the API | `api stop` | `seedpass api stop` |
|
| Stop the API | `api stop --token TOKEN` | `seedpass api stop --token <token>` |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -214,7 +214,7 @@ Set the `SEEDPASS_CORS_ORIGINS` environment variable to a comma‑separated list
|
|||||||
SEEDPASS_CORS_ORIGINS=http://localhost:3000 seedpass api start
|
SEEDPASS_CORS_ORIGINS=http://localhost:3000 seedpass api start
|
||||||
```
|
```
|
||||||
|
|
||||||
Shut down the server with `seedpass api stop`.
|
Shut down the server with `seedpass api stop --token <token>`.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
|
@@ -9,8 +9,6 @@ import secrets
|
|||||||
import queue
|
import queue
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
import jwt
|
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
from fastapi import FastAPI, Header, HTTPException, Request, Response
|
from fastapi import FastAPI, Header, HTTPException, Request, Response
|
||||||
@@ -18,8 +16,8 @@ from fastapi.concurrency import run_in_threadpool
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
import hashlib
|
|
||||||
import hmac
|
import bcrypt
|
||||||
|
|
||||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
@@ -50,16 +48,9 @@ def _get_pm(request: Request) -> PasswordManager:
|
|||||||
def _check_token(request: Request, auth: str | None) -> None:
|
def _check_token(request: Request, auth: str | None) -> None:
|
||||||
if auth is None or not auth.startswith("Bearer "):
|
if auth is None or not auth.startswith("Bearer "):
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
token = auth.split(" ", 1)[1]
|
token = auth.split(" ", 1)[1].encode()
|
||||||
jwt_secret = getattr(request.app.state, "jwt_secret", "")
|
token_hash = getattr(request.app.state, "token_hash", b"")
|
||||||
token_hash = getattr(request.app.state, "token_hash", "")
|
if not token_hash or not bcrypt.checkpw(token, token_hash):
|
||||||
try:
|
|
||||||
jwt.decode(token, jwt_secret, algorithms=["HS256"])
|
|
||||||
except jwt.ExpiredSignatureError:
|
|
||||||
raise HTTPException(status_code=401, detail="Token expired")
|
|
||||||
except jwt.InvalidTokenError:
|
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
|
||||||
if not hmac.compare_digest(hashlib.sha256(token.encode()).hexdigest(), token_hash):
|
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
|
||||||
@@ -78,7 +69,7 @@ def _reload_relays(request: Request, relays: list[str]) -> None:
|
|||||||
|
|
||||||
|
|
||||||
def start_server(fingerprint: str | None = None) -> str:
|
def start_server(fingerprint: str | None = None) -> str:
|
||||||
"""Initialize global state and return a short-lived JWT token.
|
"""Initialize global state and return a random API token.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
@@ -90,10 +81,8 @@ def start_server(fingerprint: str | None = None) -> str:
|
|||||||
else:
|
else:
|
||||||
pm = PasswordManager(fingerprint=fingerprint)
|
pm = PasswordManager(fingerprint=fingerprint)
|
||||||
app.state.pm = pm
|
app.state.pm = pm
|
||||||
app.state.jwt_secret = secrets.token_urlsafe(32)
|
raw_token = secrets.token_urlsafe(32)
|
||||||
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
app.state.token_hash = bcrypt.hashpw(raw_token.encode(), bcrypt.gensalt())
|
||||||
raw_token = jwt.encode(payload, app.state.jwt_secret, algorithm="HS256")
|
|
||||||
app.state.token_hash = hashlib.sha256(raw_token.encode()).hexdigest()
|
|
||||||
if not getattr(app.state, "limiter", None):
|
if not getattr(app.state, "limiter", None):
|
||||||
app.state.limiter = limiter
|
app.state.limiter = limiter
|
||||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
@@ -13,19 +13,25 @@ app = typer.Typer(help="Run the API server")
|
|||||||
def api_start(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> None:
|
def api_start(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> None:
|
||||||
"""Start the SeedPass API server."""
|
"""Start the SeedPass API server."""
|
||||||
token = api_module.start_server(ctx.obj.get("fingerprint"))
|
token = api_module.start_server(ctx.obj.get("fingerprint"))
|
||||||
typer.echo(f"API token: {token}")
|
typer.echo(
|
||||||
|
f"API token: {token}\nWARNING: Store this token securely; it cannot be recovered."
|
||||||
|
)
|
||||||
uvicorn.run(api_module.app, host=host, port=port)
|
uvicorn.run(api_module.app, host=host, port=port)
|
||||||
|
|
||||||
|
|
||||||
@app.command("stop")
|
@app.command("stop")
|
||||||
def api_stop(ctx: typer.Context, host: str = "127.0.0.1", port: int = 8000) -> None:
|
def api_stop(
|
||||||
|
token: str = typer.Option(..., help="API token"),
|
||||||
|
host: str = "127.0.0.1",
|
||||||
|
port: int = 8000,
|
||||||
|
) -> None:
|
||||||
"""Stop the SeedPass API server."""
|
"""Stop the SeedPass API server."""
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
try:
|
try:
|
||||||
requests.post(
|
requests.post(
|
||||||
f"http://{host}:{port}/api/v1/shutdown",
|
f"http://{host}:{port}/api/v1/shutdown",
|
||||||
headers={"Authorization": f"Bearer {api_module.app.state.token_hash}"},
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
timeout=2,
|
timeout=2,
|
||||||
)
|
)
|
||||||
except Exception as exc: # pragma: no cover - best effort
|
except Exception as exc: # pragma: no cover - best effort
|
||||||
|
@@ -4,7 +4,7 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from httpx import ASGITransport, AsyncClient
|
from httpx import ASGITransport, AsyncClient
|
||||||
import hashlib
|
import bcrypt
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||||
|
|
||||||
@@ -54,7 +54,7 @@ async def client(monkeypatch):
|
|||||||
async def test_token_hashed(client):
|
async def test_token_hashed(client):
|
||||||
_, token = client
|
_, token = client
|
||||||
assert api.app.state.token_hash != token
|
assert api.app.state.token_hash != token
|
||||||
assert api.app.state.token_hash == hashlib.sha256(token.encode()).hexdigest()
|
assert bcrypt.checkpw(token.encode(), api.app.state.token_hash)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.anyio
|
@pytest.mark.anyio
|
||||||
|
Reference in New Issue
Block a user