mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-08 07:18:47 +00:00
Merge pull request #721 from PR0M3TH3AN/codex/add-rate-limiting-to-api
Add request rate limiting to API
This commit is contained in:
@@ -37,3 +37,4 @@ argon2-cffi
|
||||
toga-core>=0.5.2
|
||||
pillow
|
||||
toga-dummy>=0.5.2 # for headless GUI tests
|
||||
slowapi
|
||||
|
@@ -17,11 +17,21 @@ import asyncio
|
||||
import sys
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
from slowapi.middleware import SlowAPIMiddleware
|
||||
|
||||
from seedpass.core.manager import PasswordManager
|
||||
from seedpass.core.entry_types import EntryType
|
||||
from seedpass.core.api import UtilityService
|
||||
|
||||
|
||||
_RATE_LIMIT = int(os.getenv("SEEDPASS_RATE_LIMIT", "100"))
|
||||
_RATE_WINDOW = int(os.getenv("SEEDPASS_RATE_WINDOW", "60"))
|
||||
_RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds"
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR])
|
||||
app = FastAPI()
|
||||
|
||||
_pm: Optional[PasswordManager] = None
|
||||
@@ -71,6 +81,10 @@ def start_server(fingerprint: str | None = None) -> str:
|
||||
_jwt_secret = secrets.token_urlsafe(32)
|
||||
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
||||
_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
|
||||
if not getattr(app.state, "limiter", None):
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
app.add_middleware(SlowAPIMiddleware)
|
||||
origins = [
|
||||
o.strip()
|
||||
for o in os.getenv("SEEDPASS_CORS_ORIGINS", "").split(",")
|
||||
@@ -157,6 +171,7 @@ def create_entry(
|
||||
"min_special",
|
||||
]
|
||||
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
|
||||
|
||||
index = _pm.entry_manager.add_entry(
|
||||
entry.get("label"),
|
||||
int(entry.get("length", 12)),
|
||||
@@ -168,6 +183,7 @@ def create_entry(
|
||||
|
||||
if etype == "totp":
|
||||
index = _pm.entry_manager.get_next_index()
|
||||
|
||||
uri = _pm.entry_manager.add_totp(
|
||||
entry.get("label"),
|
||||
_pm.parent_seed,
|
||||
@@ -295,6 +311,7 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any:
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
value = _pm.config_manager.load_config(require_pin=False).get(key)
|
||||
|
||||
if value is None:
|
||||
raise HTTPException(status_code=404, detail="Not found")
|
||||
return {"key": key, "value": value}
|
||||
@@ -320,6 +337,7 @@ def update_config(
|
||||
}
|
||||
|
||||
action = mapping.get(key)
|
||||
|
||||
if action is None:
|
||||
raise HTTPException(status_code=400, detail="Unknown key")
|
||||
|
||||
@@ -338,7 +356,9 @@ def set_secret_mode(
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
enabled = data.get("enabled")
|
||||
|
||||
delay = data.get("delay")
|
||||
|
||||
if enabled is None or delay is None:
|
||||
raise HTTPException(status_code=400, detail="Missing fields")
|
||||
cfg = _pm.config_manager
|
||||
@@ -384,6 +404,7 @@ def select_fingerprint(
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
fp = data.get("fingerprint")
|
||||
|
||||
if not fp:
|
||||
raise HTTPException(status_code=400, detail="Missing fingerprint")
|
||||
_pm.select_fingerprint(fp)
|
||||
@@ -409,7 +430,9 @@ def get_totp_codes(authorization: str | None = Header(None)) -> dict:
|
||||
codes = []
|
||||
for idx, label, _u, _url, _arch in entries:
|
||||
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
|
||||
|
||||
rem = _pm.entry_manager.get_totp_time_remaining(idx)
|
||||
|
||||
codes.append(
|
||||
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
|
||||
)
|
||||
@@ -433,6 +456,7 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
|
||||
while True:
|
||||
try:
|
||||
note = _pm.notifications.get_nowait()
|
||||
|
||||
except queue.Empty:
|
||||
break
|
||||
notes.append({"level": note.level, "message": note.message})
|
||||
@@ -461,10 +485,12 @@ def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str,
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
url = data.get("url")
|
||||
|
||||
if not url:
|
||||
raise HTTPException(status_code=400, detail="Missing url")
|
||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||
relays = cfg.get("relays", [])
|
||||
|
||||
if url in relays:
|
||||
raise HTTPException(status_code=400, detail="Relay already present")
|
||||
relays.append(url)
|
||||
@@ -480,6 +506,7 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str
|
||||
assert _pm is not None
|
||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||
relays = cfg.get("relays", [])
|
||||
|
||||
if not (1 <= idx <= len(relays)):
|
||||
raise HTTPException(status_code=400, detail="Invalid index")
|
||||
if len(relays) == 1:
|
||||
@@ -546,9 +573,11 @@ async def import_vault(
|
||||
assert _pm is not None
|
||||
|
||||
ctype = request.headers.get("content-type", "")
|
||||
|
||||
if ctype.startswith("multipart/form-data"):
|
||||
form = await request.form()
|
||||
file = form.get("file")
|
||||
|
||||
if file is None:
|
||||
raise HTTPException(status_code=400, detail="Missing file")
|
||||
data = await file.read()
|
||||
@@ -562,6 +591,7 @@ async def import_vault(
|
||||
else:
|
||||
body = await request.json()
|
||||
path = body.get("path")
|
||||
|
||||
if not path:
|
||||
raise HTTPException(status_code=400, detail="Missing file or path")
|
||||
_pm.handle_import_database(Path(path))
|
||||
@@ -581,9 +611,11 @@ def backup_parent_seed(
|
||||
assert _pm is not None
|
||||
|
||||
if not data.get("confirm"):
|
||||
|
||||
raise HTTPException(status_code=400, detail="Confirmation required")
|
||||
|
||||
path_str = data.get("path")
|
||||
|
||||
if not path_str:
|
||||
raise HTTPException(status_code=400, detail="Missing path")
|
||||
path = Path(path_str)
|
||||
@@ -600,6 +632,7 @@ def change_password(
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
_pm.change_password(data.get("old", ""), data.get("new", ""))
|
||||
|
||||
return {"status": "ok"}
|
||||
|
||||
|
||||
@@ -611,6 +644,7 @@ def generate_password(
|
||||
_check_token(authorization)
|
||||
assert _pm is not None
|
||||
length = int(data.get("length", 12))
|
||||
|
||||
policy_keys = [
|
||||
"include_special_chars",
|
||||
"allowed_special_chars",
|
||||
@@ -622,6 +656,7 @@ def generate_password(
|
||||
"min_special",
|
||||
]
|
||||
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
|
||||
|
||||
util = UtilityService(_pm)
|
||||
password = util.generate_password(length, **kwargs)
|
||||
return {"password": password}
|
||||
@@ -640,4 +675,5 @@ def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||
_check_token(authorization)
|
||||
asyncio.get_event_loop().call_soon(sys.exit, 0)
|
||||
|
||||
return {"status": "shutting down"}
|
||||
|
42
src/tests/test_api_rate_limit.py
Normal file
42
src/tests/test_api_rate_limit.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import importlib
|
||||
from pathlib import Path
|
||||
from types import SimpleNamespace
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
import sys
|
||||
|
||||
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||
|
||||
|
||||
def test_rate_limit_exceeded(monkeypatch):
|
||||
monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2")
|
||||
monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60")
|
||||
import seedpass.api as api
|
||||
|
||||
importlib.reload(api)
|
||||
|
||||
dummy = SimpleNamespace(
|
||||
entry_manager=SimpleNamespace(
|
||||
search_entries=lambda q: [
|
||||
(1, "Site", "user", "url", False, SimpleNamespace(value="password"))
|
||||
]
|
||||
),
|
||||
config_manager=SimpleNamespace(load_config=lambda require_pin=False: {}),
|
||||
fingerprint_manager=SimpleNamespace(list_fingerprints=lambda: []),
|
||||
nostr_client=SimpleNamespace(
|
||||
key_manager=SimpleNamespace(get_npub=lambda: "np")
|
||||
),
|
||||
verify_password=lambda pw: True,
|
||||
)
|
||||
monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
|
||||
token = api.start_server()
|
||||
client = TestClient(api.app)
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
|
||||
for _ in range(2):
|
||||
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
assert res.status_code == 200
|
||||
|
||||
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||
assert res.status_code == 429
|
Reference in New Issue
Block a user