mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-08 15:28:44 +00:00
Defer nostr client connections
This commit is contained in:
@@ -123,12 +123,22 @@ class NostrClient:
|
||||
signer = NostrSigner.keys(self.keys)
|
||||
self.client = Client(signer)
|
||||
|
||||
self.initialize_client_pool()
|
||||
self._connected = False
|
||||
|
||||
def connect(self) -> None:
|
||||
"""Connect the client to all configured relays."""
|
||||
if not self._connected:
|
||||
self.initialize_client_pool()
|
||||
|
||||
def initialize_client_pool(self) -> None:
|
||||
"""Add relays to the client and connect."""
|
||||
asyncio.run(self._initialize_client_pool())
|
||||
|
||||
async def _connect_async(self) -> None:
|
||||
"""Ensure the client is connected within an async context."""
|
||||
if not self._connected:
|
||||
await self._initialize_client_pool()
|
||||
|
||||
async def _initialize_client_pool(self) -> None:
|
||||
if hasattr(self.client, "add_relays"):
|
||||
await self.client.add_relays(self.relays)
|
||||
@@ -136,6 +146,7 @@ class NostrClient:
|
||||
for relay in self.relays:
|
||||
await self.client.add_relay(relay)
|
||||
await self.client.connect()
|
||||
self._connected = True
|
||||
logger.info(f"NostrClient connected to relays: {self.relays}")
|
||||
|
||||
async def _ping_relay(self, relay: str, timeout: float) -> bool:
|
||||
@@ -190,6 +201,7 @@ class NostrClient:
|
||||
If provided, include an ``alt`` tag so uploads can be
|
||||
associated with a specific event like a password change.
|
||||
"""
|
||||
self.connect()
|
||||
self.last_error = None
|
||||
try:
|
||||
content = base64.b64encode(encrypted_json).decode("utf-8")
|
||||
@@ -221,9 +233,11 @@ class NostrClient:
|
||||
|
||||
def publish_event(self, event):
|
||||
"""Publish a prepared event to the configured relays."""
|
||||
self.connect()
|
||||
return asyncio.run(self._publish_event(event))
|
||||
|
||||
async def _publish_event(self, event):
|
||||
await self._connect_async()
|
||||
return await self.client.send_event(event)
|
||||
|
||||
def update_relays(self, new_relays: List[str]) -> None:
|
||||
@@ -232,12 +246,13 @@ class NostrClient:
|
||||
self.relays = new_relays
|
||||
signer = NostrSigner.keys(self.keys)
|
||||
self.client = Client(signer)
|
||||
self.initialize_client_pool()
|
||||
self._connected = False
|
||||
|
||||
def retrieve_json_from_nostr_sync(
|
||||
self, retries: int = 0, delay: float = 2.0
|
||||
) -> Optional[bytes]:
|
||||
"""Retrieve the latest Kind 1 event from the author with optional retries."""
|
||||
self.connect()
|
||||
self.last_error = None
|
||||
attempt = 0
|
||||
while True:
|
||||
@@ -255,6 +270,7 @@ class NostrClient:
|
||||
return None
|
||||
|
||||
async def _retrieve_json_from_nostr(self) -> Optional[bytes]:
|
||||
await self._connect_async()
|
||||
# Filter for the latest text note (Kind 1) from our public key
|
||||
pubkey = self.keys.public_key()
|
||||
f = Filter().author(pubkey).kind(Kind.from_std(KindStandard.TEXT_NOTE)).limit(1)
|
||||
@@ -288,6 +304,7 @@ class NostrClient:
|
||||
Maximum chunk size in bytes. Defaults to 50 kB.
|
||||
"""
|
||||
|
||||
await self._connect_async()
|
||||
manifest, chunks = prepare_snapshot(encrypted_bytes, limit)
|
||||
for meta, chunk in zip(manifest.chunks, chunks):
|
||||
content = base64.b64encode(chunk).decode("utf-8")
|
||||
@@ -320,6 +337,8 @@ class NostrClient:
|
||||
async def fetch_latest_snapshot(self) -> Tuple[Manifest, list[bytes]] | None:
|
||||
"""Retrieve the latest manifest and all snapshot chunks."""
|
||||
|
||||
await self._connect_async()
|
||||
|
||||
pubkey = self.keys.public_key()
|
||||
f = Filter().author(pubkey).kind(Kind(KIND_MANIFEST)).limit(1)
|
||||
timeout = timedelta(seconds=10)
|
||||
@@ -358,6 +377,8 @@ class NostrClient:
|
||||
async def publish_delta(self, delta_bytes: bytes, manifest_id: str) -> str:
|
||||
"""Publish a delta event referencing a manifest."""
|
||||
|
||||
await self._connect_async()
|
||||
|
||||
content = base64.b64encode(delta_bytes).decode("utf-8")
|
||||
tag = Tag.event(EventId.parse(manifest_id))
|
||||
builder = EventBuilder(Kind(KIND_DELTA), content).tags([tag])
|
||||
@@ -372,6 +393,8 @@ class NostrClient:
|
||||
async def fetch_deltas_since(self, version: int) -> list[bytes]:
|
||||
"""Retrieve delta events newer than the given version."""
|
||||
|
||||
await self._connect_async()
|
||||
|
||||
pubkey = self.keys.public_key()
|
||||
f = (
|
||||
Filter()
|
||||
@@ -409,6 +432,7 @@ class NostrClient:
|
||||
"""Disconnects the client from all relays."""
|
||||
try:
|
||||
asyncio.run(self.client.disconnect())
|
||||
self._connected = False
|
||||
logger.info("NostrClient disconnected from relays.")
|
||||
except Exception as e:
|
||||
logger.error("Error during NostrClient shutdown: %s", e)
|
||||
|
@@ -88,6 +88,7 @@ def _setup_client(tmpdir, fake_cls):
|
||||
def test_initialize_client_pool_add_relays_used(tmp_path):
|
||||
client = _setup_client(tmp_path, FakeAddRelaysClient)
|
||||
fc = client.client
|
||||
client.connect()
|
||||
assert fc.added == [client.relays]
|
||||
assert fc.connected is True
|
||||
|
||||
@@ -95,6 +96,7 @@ def test_initialize_client_pool_add_relays_used(tmp_path):
|
||||
def test_initialize_client_pool_add_relay_fallback(tmp_path):
|
||||
client = _setup_client(tmp_path, FakeAddRelayClient)
|
||||
fc = client.client
|
||||
client.connect()
|
||||
assert fc.added == client.relays
|
||||
assert fc.connected is True
|
||||
|
||||
|
Reference in New Issue
Block a user