mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-09 07:48:57 +00:00
Merge pull request #721 from PR0M3TH3AN/codex/add-rate-limiting-to-api
Add request rate limiting to API
This commit is contained in:
@@ -37,3 +37,4 @@ argon2-cffi
|
|||||||
toga-core>=0.5.2
|
toga-core>=0.5.2
|
||||||
pillow
|
pillow
|
||||||
toga-dummy>=0.5.2 # for headless GUI tests
|
toga-dummy>=0.5.2 # for headless GUI tests
|
||||||
|
slowapi
|
||||||
|
@@ -17,11 +17,21 @@ import asyncio
|
|||||||
import sys
|
import sys
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||||
|
from slowapi.errors import RateLimitExceeded
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from slowapi.middleware import SlowAPIMiddleware
|
||||||
|
|
||||||
from seedpass.core.manager import PasswordManager
|
from seedpass.core.manager import PasswordManager
|
||||||
from seedpass.core.entry_types import EntryType
|
from seedpass.core.entry_types import EntryType
|
||||||
from seedpass.core.api import UtilityService
|
from seedpass.core.api import UtilityService
|
||||||
|
|
||||||
|
|
||||||
|
_RATE_LIMIT = int(os.getenv("SEEDPASS_RATE_LIMIT", "100"))
|
||||||
|
_RATE_WINDOW = int(os.getenv("SEEDPASS_RATE_WINDOW", "60"))
|
||||||
|
_RATE_LIMIT_STR = f"{_RATE_LIMIT}/{_RATE_WINDOW} seconds"
|
||||||
|
|
||||||
|
limiter = Limiter(key_func=get_remote_address, default_limits=[_RATE_LIMIT_STR])
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
|
|
||||||
_pm: Optional[PasswordManager] = None
|
_pm: Optional[PasswordManager] = None
|
||||||
@@ -71,6 +81,10 @@ def start_server(fingerprint: str | None = None) -> str:
|
|||||||
_jwt_secret = secrets.token_urlsafe(32)
|
_jwt_secret = secrets.token_urlsafe(32)
|
||||||
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
payload = {"exp": datetime.now(timezone.utc) + timedelta(minutes=5)}
|
||||||
_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
|
_token = jwt.encode(payload, _jwt_secret, algorithm="HS256")
|
||||||
|
if not getattr(app.state, "limiter", None):
|
||||||
|
app.state.limiter = limiter
|
||||||
|
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||||
|
app.add_middleware(SlowAPIMiddleware)
|
||||||
origins = [
|
origins = [
|
||||||
o.strip()
|
o.strip()
|
||||||
for o in os.getenv("SEEDPASS_CORS_ORIGINS", "").split(",")
|
for o in os.getenv("SEEDPASS_CORS_ORIGINS", "").split(",")
|
||||||
@@ -157,6 +171,7 @@ def create_entry(
|
|||||||
"min_special",
|
"min_special",
|
||||||
]
|
]
|
||||||
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
|
kwargs = {k: entry.get(k) for k in policy_keys if entry.get(k) is not None}
|
||||||
|
|
||||||
index = _pm.entry_manager.add_entry(
|
index = _pm.entry_manager.add_entry(
|
||||||
entry.get("label"),
|
entry.get("label"),
|
||||||
int(entry.get("length", 12)),
|
int(entry.get("length", 12)),
|
||||||
@@ -168,6 +183,7 @@ def create_entry(
|
|||||||
|
|
||||||
if etype == "totp":
|
if etype == "totp":
|
||||||
index = _pm.entry_manager.get_next_index()
|
index = _pm.entry_manager.get_next_index()
|
||||||
|
|
||||||
uri = _pm.entry_manager.add_totp(
|
uri = _pm.entry_manager.add_totp(
|
||||||
entry.get("label"),
|
entry.get("label"),
|
||||||
_pm.parent_seed,
|
_pm.parent_seed,
|
||||||
@@ -295,6 +311,7 @@ def get_config(key: str, authorization: str | None = Header(None)) -> Any:
|
|||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
value = _pm.config_manager.load_config(require_pin=False).get(key)
|
value = _pm.config_manager.load_config(require_pin=False).get(key)
|
||||||
|
|
||||||
if value is None:
|
if value is None:
|
||||||
raise HTTPException(status_code=404, detail="Not found")
|
raise HTTPException(status_code=404, detail="Not found")
|
||||||
return {"key": key, "value": value}
|
return {"key": key, "value": value}
|
||||||
@@ -320,6 +337,7 @@ def update_config(
|
|||||||
}
|
}
|
||||||
|
|
||||||
action = mapping.get(key)
|
action = mapping.get(key)
|
||||||
|
|
||||||
if action is None:
|
if action is None:
|
||||||
raise HTTPException(status_code=400, detail="Unknown key")
|
raise HTTPException(status_code=400, detail="Unknown key")
|
||||||
|
|
||||||
@@ -338,7 +356,9 @@ def set_secret_mode(
|
|||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
enabled = data.get("enabled")
|
enabled = data.get("enabled")
|
||||||
|
|
||||||
delay = data.get("delay")
|
delay = data.get("delay")
|
||||||
|
|
||||||
if enabled is None or delay is None:
|
if enabled is None or delay is None:
|
||||||
raise HTTPException(status_code=400, detail="Missing fields")
|
raise HTTPException(status_code=400, detail="Missing fields")
|
||||||
cfg = _pm.config_manager
|
cfg = _pm.config_manager
|
||||||
@@ -384,6 +404,7 @@ def select_fingerprint(
|
|||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
fp = data.get("fingerprint")
|
fp = data.get("fingerprint")
|
||||||
|
|
||||||
if not fp:
|
if not fp:
|
||||||
raise HTTPException(status_code=400, detail="Missing fingerprint")
|
raise HTTPException(status_code=400, detail="Missing fingerprint")
|
||||||
_pm.select_fingerprint(fp)
|
_pm.select_fingerprint(fp)
|
||||||
@@ -409,7 +430,9 @@ def get_totp_codes(authorization: str | None = Header(None)) -> dict:
|
|||||||
codes = []
|
codes = []
|
||||||
for idx, label, _u, _url, _arch in entries:
|
for idx, label, _u, _url, _arch in entries:
|
||||||
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
|
code = _pm.entry_manager.get_totp_code(idx, _pm.parent_seed)
|
||||||
|
|
||||||
rem = _pm.entry_manager.get_totp_time_remaining(idx)
|
rem = _pm.entry_manager.get_totp_time_remaining(idx)
|
||||||
|
|
||||||
codes.append(
|
codes.append(
|
||||||
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
|
{"id": idx, "label": label, "code": code, "seconds_remaining": rem}
|
||||||
)
|
)
|
||||||
@@ -433,6 +456,7 @@ def get_notifications(authorization: str | None = Header(None)) -> List[dict]:
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
note = _pm.notifications.get_nowait()
|
note = _pm.notifications.get_nowait()
|
||||||
|
|
||||||
except queue.Empty:
|
except queue.Empty:
|
||||||
break
|
break
|
||||||
notes.append({"level": note.level, "message": note.message})
|
notes.append({"level": note.level, "message": note.message})
|
||||||
@@ -461,10 +485,12 @@ def add_relay(data: dict, authorization: str | None = Header(None)) -> dict[str,
|
|||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
url = data.get("url")
|
url = data.get("url")
|
||||||
|
|
||||||
if not url:
|
if not url:
|
||||||
raise HTTPException(status_code=400, detail="Missing url")
|
raise HTTPException(status_code=400, detail="Missing url")
|
||||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||||
relays = cfg.get("relays", [])
|
relays = cfg.get("relays", [])
|
||||||
|
|
||||||
if url in relays:
|
if url in relays:
|
||||||
raise HTTPException(status_code=400, detail="Relay already present")
|
raise HTTPException(status_code=400, detail="Relay already present")
|
||||||
relays.append(url)
|
relays.append(url)
|
||||||
@@ -480,6 +506,7 @@ def remove_relay(idx: int, authorization: str | None = Header(None)) -> dict[str
|
|||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
cfg = _pm.config_manager.load_config(require_pin=False)
|
cfg = _pm.config_manager.load_config(require_pin=False)
|
||||||
relays = cfg.get("relays", [])
|
relays = cfg.get("relays", [])
|
||||||
|
|
||||||
if not (1 <= idx <= len(relays)):
|
if not (1 <= idx <= len(relays)):
|
||||||
raise HTTPException(status_code=400, detail="Invalid index")
|
raise HTTPException(status_code=400, detail="Invalid index")
|
||||||
if len(relays) == 1:
|
if len(relays) == 1:
|
||||||
@@ -546,9 +573,11 @@ async def import_vault(
|
|||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
|
|
||||||
ctype = request.headers.get("content-type", "")
|
ctype = request.headers.get("content-type", "")
|
||||||
|
|
||||||
if ctype.startswith("multipart/form-data"):
|
if ctype.startswith("multipart/form-data"):
|
||||||
form = await request.form()
|
form = await request.form()
|
||||||
file = form.get("file")
|
file = form.get("file")
|
||||||
|
|
||||||
if file is None:
|
if file is None:
|
||||||
raise HTTPException(status_code=400, detail="Missing file")
|
raise HTTPException(status_code=400, detail="Missing file")
|
||||||
data = await file.read()
|
data = await file.read()
|
||||||
@@ -562,6 +591,7 @@ async def import_vault(
|
|||||||
else:
|
else:
|
||||||
body = await request.json()
|
body = await request.json()
|
||||||
path = body.get("path")
|
path = body.get("path")
|
||||||
|
|
||||||
if not path:
|
if not path:
|
||||||
raise HTTPException(status_code=400, detail="Missing file or path")
|
raise HTTPException(status_code=400, detail="Missing file or path")
|
||||||
_pm.handle_import_database(Path(path))
|
_pm.handle_import_database(Path(path))
|
||||||
@@ -581,9 +611,11 @@ def backup_parent_seed(
|
|||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
|
|
||||||
if not data.get("confirm"):
|
if not data.get("confirm"):
|
||||||
|
|
||||||
raise HTTPException(status_code=400, detail="Confirmation required")
|
raise HTTPException(status_code=400, detail="Confirmation required")
|
||||||
|
|
||||||
path_str = data.get("path")
|
path_str = data.get("path")
|
||||||
|
|
||||||
if not path_str:
|
if not path_str:
|
||||||
raise HTTPException(status_code=400, detail="Missing path")
|
raise HTTPException(status_code=400, detail="Missing path")
|
||||||
path = Path(path_str)
|
path = Path(path_str)
|
||||||
@@ -600,6 +632,7 @@ def change_password(
|
|||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
_pm.change_password(data.get("old", ""), data.get("new", ""))
|
_pm.change_password(data.get("old", ""), data.get("new", ""))
|
||||||
|
|
||||||
return {"status": "ok"}
|
return {"status": "ok"}
|
||||||
|
|
||||||
|
|
||||||
@@ -611,6 +644,7 @@ def generate_password(
|
|||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
assert _pm is not None
|
assert _pm is not None
|
||||||
length = int(data.get("length", 12))
|
length = int(data.get("length", 12))
|
||||||
|
|
||||||
policy_keys = [
|
policy_keys = [
|
||||||
"include_special_chars",
|
"include_special_chars",
|
||||||
"allowed_special_chars",
|
"allowed_special_chars",
|
||||||
@@ -622,6 +656,7 @@ def generate_password(
|
|||||||
"min_special",
|
"min_special",
|
||||||
]
|
]
|
||||||
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
|
kwargs = {k: data.get(k) for k in policy_keys if data.get(k) is not None}
|
||||||
|
|
||||||
util = UtilityService(_pm)
|
util = UtilityService(_pm)
|
||||||
password = util.generate_password(length, **kwargs)
|
password = util.generate_password(length, **kwargs)
|
||||||
return {"password": password}
|
return {"password": password}
|
||||||
@@ -640,4 +675,5 @@ def lock_vault(authorization: str | None = Header(None)) -> dict[str, str]:
|
|||||||
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
|
async def shutdown_server(authorization: str | None = Header(None)) -> dict[str, str]:
|
||||||
_check_token(authorization)
|
_check_token(authorization)
|
||||||
asyncio.get_event_loop().call_soon(sys.exit, 0)
|
asyncio.get_event_loop().call_soon(sys.exit, 0)
|
||||||
|
|
||||||
return {"status": "shutting down"}
|
return {"status": "shutting down"}
|
||||||
|
42
src/tests/test_api_rate_limit.py
Normal file
42
src/tests/test_api_rate_limit.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
import importlib
|
||||||
|
from pathlib import Path
|
||||||
|
from types import SimpleNamespace
|
||||||
|
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
import sys
|
||||||
|
|
||||||
|
sys.path.append(str(Path(__file__).resolve().parents[1]))
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_exceeded(monkeypatch):
|
||||||
|
monkeypatch.setenv("SEEDPASS_RATE_LIMIT", "2")
|
||||||
|
monkeypatch.setenv("SEEDPASS_RATE_WINDOW", "60")
|
||||||
|
import seedpass.api as api
|
||||||
|
|
||||||
|
importlib.reload(api)
|
||||||
|
|
||||||
|
dummy = SimpleNamespace(
|
||||||
|
entry_manager=SimpleNamespace(
|
||||||
|
search_entries=lambda q: [
|
||||||
|
(1, "Site", "user", "url", False, SimpleNamespace(value="password"))
|
||||||
|
]
|
||||||
|
),
|
||||||
|
config_manager=SimpleNamespace(load_config=lambda require_pin=False: {}),
|
||||||
|
fingerprint_manager=SimpleNamespace(list_fingerprints=lambda: []),
|
||||||
|
nostr_client=SimpleNamespace(
|
||||||
|
key_manager=SimpleNamespace(get_npub=lambda: "np")
|
||||||
|
),
|
||||||
|
verify_password=lambda pw: True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(api, "PasswordManager", lambda: dummy)
|
||||||
|
token = api.start_server()
|
||||||
|
client = TestClient(api.app)
|
||||||
|
headers = {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
for _ in range(2):
|
||||||
|
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||||
|
assert res.status_code == 200
|
||||||
|
|
||||||
|
res = client.get("/api/v1/entry", params={"query": "s"}, headers=headers)
|
||||||
|
assert res.status_code == 429
|
Reference in New Issue
Block a user