Defer nostr client connections

This commit is contained in:
thePR0M3TH3AN
2025-07-11 22:50:08 -04:00
parent babb4d49f0
commit 1e270c9ab1
2 changed files with 28 additions and 2 deletions

View File

@@ -123,12 +123,22 @@ class NostrClient:
signer = NostrSigner.keys(self.keys)
self.client = Client(signer)
self._connected = False
def connect(self) -> None:
"""Connect the client to all configured relays."""
if not self._connected:
self.initialize_client_pool()
def initialize_client_pool(self) -> None:
"""Add relays to the client and connect."""
asyncio.run(self._initialize_client_pool())
async def _connect_async(self) -> None:
"""Ensure the client is connected within an async context."""
if not self._connected:
await self._initialize_client_pool()
async def _initialize_client_pool(self) -> None:
if hasattr(self.client, "add_relays"):
await self.client.add_relays(self.relays)
@@ -136,6 +146,7 @@ class NostrClient:
for relay in self.relays:
await self.client.add_relay(relay)
await self.client.connect()
self._connected = True
logger.info(f"NostrClient connected to relays: {self.relays}")
async def _ping_relay(self, relay: str, timeout: float) -> bool:
@@ -190,6 +201,7 @@ class NostrClient:
If provided, include an ``alt`` tag so uploads can be
associated with a specific event like a password change.
"""
self.connect()
self.last_error = None
try:
content = base64.b64encode(encrypted_json).decode("utf-8")
@@ -221,9 +233,11 @@ class NostrClient:
def publish_event(self, event):
"""Publish a prepared event to the configured relays."""
self.connect()
return asyncio.run(self._publish_event(event))
async def _publish_event(self, event):
await self._connect_async()
return await self.client.send_event(event)
def update_relays(self, new_relays: List[str]) -> None:
@@ -232,12 +246,13 @@ class NostrClient:
self.relays = new_relays
signer = NostrSigner.keys(self.keys)
self.client = Client(signer)
self.initialize_client_pool()
self._connected = False
def retrieve_json_from_nostr_sync(
self, retries: int = 0, delay: float = 2.0
) -> Optional[bytes]:
"""Retrieve the latest Kind 1 event from the author with optional retries."""
self.connect()
self.last_error = None
attempt = 0
while True:
@@ -255,6 +270,7 @@ class NostrClient:
return None
async def _retrieve_json_from_nostr(self) -> Optional[bytes]:
await self._connect_async()
# Filter for the latest text note (Kind 1) from our public key
pubkey = self.keys.public_key()
f = Filter().author(pubkey).kind(Kind.from_std(KindStandard.TEXT_NOTE)).limit(1)
@@ -288,6 +304,7 @@ class NostrClient:
Maximum chunk size in bytes. Defaults to 50 kB.
"""
await self._connect_async()
manifest, chunks = prepare_snapshot(encrypted_bytes, limit)
for meta, chunk in zip(manifest.chunks, chunks):
content = base64.b64encode(chunk).decode("utf-8")
@@ -320,6 +337,8 @@ class NostrClient:
async def fetch_latest_snapshot(self) -> Tuple[Manifest, list[bytes]] | None:
"""Retrieve the latest manifest and all snapshot chunks."""
await self._connect_async()
pubkey = self.keys.public_key()
f = Filter().author(pubkey).kind(Kind(KIND_MANIFEST)).limit(1)
timeout = timedelta(seconds=10)
@@ -358,6 +377,8 @@ class NostrClient:
async def publish_delta(self, delta_bytes: bytes, manifest_id: str) -> str:
"""Publish a delta event referencing a manifest."""
await self._connect_async()
content = base64.b64encode(delta_bytes).decode("utf-8")
tag = Tag.event(EventId.parse(manifest_id))
builder = EventBuilder(Kind(KIND_DELTA), content).tags([tag])
@@ -372,6 +393,8 @@ class NostrClient:
async def fetch_deltas_since(self, version: int) -> list[bytes]:
"""Retrieve delta events newer than the given version."""
await self._connect_async()
pubkey = self.keys.public_key()
f = (
Filter()
@@ -409,6 +432,7 @@ class NostrClient:
"""Disconnects the client from all relays."""
try:
asyncio.run(self.client.disconnect())
self._connected = False
logger.info("NostrClient disconnected from relays.")
except Exception as e:
logger.error("Error during NostrClient shutdown: %s", e)

View File

@@ -88,6 +88,7 @@ def _setup_client(tmpdir, fake_cls):
def test_initialize_client_pool_add_relays_used(tmp_path):
client = _setup_client(tmp_path, FakeAddRelaysClient)
fc = client.client
client.connect()
assert fc.added == [client.relays]
assert fc.connected is True
@@ -95,6 +96,7 @@ def test_initialize_client_pool_add_relays_used(tmp_path):
def test_initialize_client_pool_add_relay_fallback(tmp_path):
client = _setup_client(tmp_path, FakeAddRelayClient)
fc = client.client
client.connect()
assert fc.added == client.relays
assert fc.connected is True