diff --git a/src/constants.py b/src/constants.py index 6a86868..dfcd0d1 100644 --- a/src/constants.py +++ b/src/constants.py @@ -11,6 +11,7 @@ logger = logging.getLogger(__name__) # ----------------------------------- MAX_RETRIES = 3 # Maximum number of retries for relay connections RETRY_DELAY = 5 # Seconds to wait before retrying a failed connection +MIN_HEALTHY_RELAYS = 2 # Minimum relays that should return data on startup # ----------------------------------- # Application Directory and Paths diff --git a/src/nostr/client.py b/src/nostr/client.py index 20fd2f3..2ff7019 100644 --- a/src/nostr/client.py +++ b/src/nostr/client.py @@ -8,6 +8,7 @@ from typing import List, Optional, Tuple import hashlib import asyncio import gzip +import websockets # Imports from the nostr-sdk library from nostr_sdk import ( @@ -137,6 +138,42 @@ class NostrClient: await self.client.connect() logger.info(f"NostrClient connected to relays: {self.relays}") + async def _ping_relay(self, relay: str, timeout: float) -> bool: + """Attempt to retrieve the latest event from a single relay.""" + sub_id = "seedpass-health" + pubkey = self.keys.public_key().to_hex() + req = json.dumps( + ["REQ", sub_id, {"kinds": [1], "authors": [pubkey], "limit": 1}] + ) + try: + async with websockets.connect( + relay, open_timeout=timeout, close_timeout=timeout + ) as ws: + await ws.send(req) + while True: + msg = await asyncio.wait_for(ws.recv(), timeout=timeout) + data = json.loads(msg) + if data[0] == "EVENT": + return True + if data[0] == "EOSE": + return False + except Exception: + return False + + async def _check_relay_health(self, min_relays: int, timeout: float) -> int: + tasks = [self._ping_relay(r, timeout) for r in self.relays] + results = await asyncio.gather(*tasks, return_exceptions=True) + healthy = sum(1 for r in results if r is True) + if healthy < min_relays: + logger.warning( + "Only %s relays responded with data; consider adding more.", healthy + ) + return healthy + + def check_relay_health(self, min_relays: int = 2, timeout: float = 5.0) -> int: + """Ping relays and return the count of those providing data.""" + return asyncio.run(self._check_relay_health(min_relays, timeout)) + def publish_json_to_nostr( self, encrypted_json: bytes, diff --git a/src/password_manager/manager.py b/src/password_manager/manager.py index 0115289..76295ff 100644 --- a/src/password_manager/manager.py +++ b/src/password_manager/manager.py @@ -40,6 +40,7 @@ from utils.password_prompt import ( prompt_existing_password, confirm_action, ) +from constants import MIN_HEALTHY_RELAYS from constants import ( APP_DIR, @@ -763,6 +764,17 @@ class PasswordManager: parent_seed=getattr(self, "parent_seed", None), ) + if hasattr(self.nostr_client, "check_relay_health"): + healthy = self.nostr_client.check_relay_health(MIN_HEALTHY_RELAYS) + if healthy < MIN_HEALTHY_RELAYS: + print( + colored( + f"Only {healthy} relay(s) responded with your latest event." + " Consider adding more relays via Settings.", + "yellow", + ) + ) + logger.debug("Managers re-initialized for the new fingerprint.") except Exception as e: diff --git a/src/tests/test_nostr_client.py b/src/tests/test_nostr_client.py index fe58737..8a849b8 100644 --- a/src/tests/test_nostr_client.py +++ b/src/tests/test_nostr_client.py @@ -75,3 +75,19 @@ def test_initialize_client_pool_add_relay_fallback(tmp_path): fc = client.client assert fc.added == client.relays assert fc.connected is True + + +def test_check_relay_health_runs_async(tmp_path, monkeypatch): + client = _setup_client(tmp_path, FakeAddRelayClient) + + recorded = {} + + async def fake_check(min_relays, timeout): + recorded["args"] = (min_relays, timeout) + return 1 + + monkeypatch.setattr(client, "_check_relay_health", fake_check) + result = client.check_relay_health(3, timeout=2) + + assert result == 1 + assert recorded["args"] == (3, 2)