From 1e270c9ab129a5c9e5e2114e9c7f996ed4101b16 Mon Sep 17 00:00:00 2001 From: thePR0M3TH3AN <53631862+PR0M3TH3AN@users.noreply.github.com> Date: Fri, 11 Jul 2025 22:50:08 -0400 Subject: [PATCH] Defer nostr client connections --- src/nostr/client.py | 28 ++++++++++++++++++++++++++-- src/tests/test_nostr_client.py | 2 ++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/src/nostr/client.py b/src/nostr/client.py index 320ba72..4ebe15f 100644 --- a/src/nostr/client.py +++ b/src/nostr/client.py @@ -123,12 +123,22 @@ class NostrClient: signer = NostrSigner.keys(self.keys) self.client = Client(signer) - self.initialize_client_pool() + self._connected = False + + def connect(self) -> None: + """Connect the client to all configured relays.""" + if not self._connected: + self.initialize_client_pool() def initialize_client_pool(self) -> None: """Add relays to the client and connect.""" asyncio.run(self._initialize_client_pool()) + async def _connect_async(self) -> None: + """Ensure the client is connected within an async context.""" + if not self._connected: + await self._initialize_client_pool() + async def _initialize_client_pool(self) -> None: if hasattr(self.client, "add_relays"): await self.client.add_relays(self.relays) @@ -136,6 +146,7 @@ class NostrClient: for relay in self.relays: await self.client.add_relay(relay) await self.client.connect() + self._connected = True logger.info(f"NostrClient connected to relays: {self.relays}") async def _ping_relay(self, relay: str, timeout: float) -> bool: @@ -190,6 +201,7 @@ class NostrClient: If provided, include an ``alt`` tag so uploads can be associated with a specific event like a password change. """ + self.connect() self.last_error = None try: content = base64.b64encode(encrypted_json).decode("utf-8") @@ -221,9 +233,11 @@ class NostrClient: def publish_event(self, event): """Publish a prepared event to the configured relays.""" + self.connect() return asyncio.run(self._publish_event(event)) async def _publish_event(self, event): + await self._connect_async() return await self.client.send_event(event) def update_relays(self, new_relays: List[str]) -> None: @@ -232,12 +246,13 @@ class NostrClient: self.relays = new_relays signer = NostrSigner.keys(self.keys) self.client = Client(signer) - self.initialize_client_pool() + self._connected = False def retrieve_json_from_nostr_sync( self, retries: int = 0, delay: float = 2.0 ) -> Optional[bytes]: """Retrieve the latest Kind 1 event from the author with optional retries.""" + self.connect() self.last_error = None attempt = 0 while True: @@ -255,6 +270,7 @@ class NostrClient: return None async def _retrieve_json_from_nostr(self) -> Optional[bytes]: + await self._connect_async() # Filter for the latest text note (Kind 1) from our public key pubkey = self.keys.public_key() f = Filter().author(pubkey).kind(Kind.from_std(KindStandard.TEXT_NOTE)).limit(1) @@ -288,6 +304,7 @@ class NostrClient: Maximum chunk size in bytes. Defaults to 50 kB. """ + await self._connect_async() manifest, chunks = prepare_snapshot(encrypted_bytes, limit) for meta, chunk in zip(manifest.chunks, chunks): content = base64.b64encode(chunk).decode("utf-8") @@ -320,6 +337,8 @@ class NostrClient: async def fetch_latest_snapshot(self) -> Tuple[Manifest, list[bytes]] | None: """Retrieve the latest manifest and all snapshot chunks.""" + await self._connect_async() + pubkey = self.keys.public_key() f = Filter().author(pubkey).kind(Kind(KIND_MANIFEST)).limit(1) timeout = timedelta(seconds=10) @@ -358,6 +377,8 @@ class NostrClient: async def publish_delta(self, delta_bytes: bytes, manifest_id: str) -> str: """Publish a delta event referencing a manifest.""" + await self._connect_async() + content = base64.b64encode(delta_bytes).decode("utf-8") tag = Tag.event(EventId.parse(manifest_id)) builder = EventBuilder(Kind(KIND_DELTA), content).tags([tag]) @@ -372,6 +393,8 @@ class NostrClient: async def fetch_deltas_since(self, version: int) -> list[bytes]: """Retrieve delta events newer than the given version.""" + await self._connect_async() + pubkey = self.keys.public_key() f = ( Filter() @@ -409,6 +432,7 @@ class NostrClient: """Disconnects the client from all relays.""" try: asyncio.run(self.client.disconnect()) + self._connected = False logger.info("NostrClient disconnected from relays.") except Exception as e: logger.error("Error during NostrClient shutdown: %s", e) diff --git a/src/tests/test_nostr_client.py b/src/tests/test_nostr_client.py index 310ab4b..9c76a73 100644 --- a/src/tests/test_nostr_client.py +++ b/src/tests/test_nostr_client.py @@ -88,6 +88,7 @@ def _setup_client(tmpdir, fake_cls): def test_initialize_client_pool_add_relays_used(tmp_path): client = _setup_client(tmp_path, FakeAddRelaysClient) fc = client.client + client.connect() assert fc.added == [client.relays] assert fc.connected is True @@ -95,6 +96,7 @@ def test_initialize_client_pool_add_relays_used(tmp_path): def test_initialize_client_pool_add_relay_fallback(tmp_path): client = _setup_client(tmp_path, FakeAddRelayClient) fc = client.client + client.connect() assert fc.added == client.relays assert fc.connected is True