mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-09 15:58:48 +00:00
Merge pull request #722 from PR0M3TH3AN/codex/hash-jwt-in-start_server-and-update-checks
Hash JWT token in API
This commit is contained in:
@@ -16,6 +16,7 @@ from fastapi import FastAPI, Header, HTTPException, Request, Response
|
|||||||
import asyncio
|
import asyncio
|
||||||
import sys
|
import sys
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
import hashlib
|
||||||
|
|
||||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
from slowapi.errors import RateLimitExceeded
|
from slowapi.errors import RateLimitExceeded
|
||||||
@@ -49,6 +50,8 @@ def _check_token(auth: str | None) -> None:
|
|||||||
raise HTTPException(status_code=401, detail="Token expired")
|
raise HTTPException(status_code=401, detail="Token expired")
|
||||||
except jwt.InvalidTokenError:
|
except jwt.InvalidTokenError:
|
||||||
raise HTTPException(status_code=401, detail="Unauthorized")
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
if hashlib.sha256(token.encode()).hexdigest() != _token:
|
||||||
|
raise HTTPException(status_code=401, detail="Unauthorized")
|
||||||
|
|
||||||
|
|
||||||
def _reload_relays(relays: list[str]) -> None:
|
def _reload_relays(relays: list[str]) -> None:
|
||||||
@@ -80,7 +83,8 @@ def start_server(fingerprint: str | None = None) -> str:
|
|||||||
_pm = PasswordManager(fingerprint=fingerprint)
|
_pm = PasswordManager(fingerprint=fingerprint)
|
||||||
_jwt_secret = secrets.token_urlsafe(32)
|
_jwt_secret = secrets.token_urlsafe(32)
|
||||||
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
||||||
_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
|
raw_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
|
||||||
|
_token = hashlib.sha256(raw_token.encode()).hexdigest()
|
||||||
if not getattr(app.state, "limiter", None):
|
if not getattr(app.state, "limiter", None):
|
||||||
app.state.limiter = limiter
|
app.state.limiter = limiter
|
||||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
@@ -97,7 +101,7 @@ def start_server(fingerprint: str | None = None) -> str:
|
|||||||
allow_methods=["*"],
|
allow_methods=["*"],
|
||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
return _token
|
return raw_token
|
||||||
|
|
||||||
|
|
||||||
def _require_password(password: str | None) -> None:
|
def _require_password(password: str | None) -> None:
|
||||||
|
@@ -4,6 +4,7 @@ import sys
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
|
import hashlib
|
||||||
|
|
||||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||||
|
|
||||||
@@ -48,6 +49,12 @@ def client(monkeypatch):
|
|||||||
return client, token
|
return client, token
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_hashed(client):
|
||||||
|
_, token = client
|
||||||
|
assert api._token != token
|
||||||
|
assert api._token == hashlib.sha256(token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def test_cors_and_auth(client):
|
def test_cors_and_auth(client):
|
||||||
cl, token = client
|
cl, token = client
|
||||||
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
headers = {"Authorization": f"Bearer {token}", "Origin": "http://example.com"}
|
||||||
|
Reference in New Issue
Block a user