Merge pull request #808 from PR0M3TH3AN/codex/convert-functions-to-async-and-update-calls

Convert Nostr client operations to async
This commit is contained in:
thePR0M3TH3AN
2025-08-10 21:06:21 -04:00
committed by GitHub
3 changed files with 42 additions and 24 deletions

View File

@@ -2,7 +2,6 @@ import asyncio
import base64 import base64
import json import json
import logging import logging
import time
from datetime import timedelta from datetime import timedelta
from typing import List, Optional from typing import List, Optional
@@ -23,12 +22,12 @@ DEFAULT_RELAYS = [
class ConnectionHandler: class ConnectionHandler:
"""Mixin providing relay connection and retry logic.""" """Mixin providing relay connection and retry logic."""
def connect(self) -> None: async def connect(self) -> None:
"""Connect the client to all configured relays.""" """Connect the client to all configured relays."""
if self.offline_mode or not self.relays: if self.offline_mode or not self.relays:
return return
if not getattr(self, "_connected", False): if not getattr(self, "_connected", False):
self.initialize_client_pool() await self._initialize_client_pool()
def initialize_client_pool(self) -> None: def initialize_client_pool(self) -> None:
"""Add relays to the client and connect.""" """Add relays to the client and connect."""
@@ -107,7 +106,7 @@ class ConnectionHandler:
return 0 return 0
return asyncio.run(self._check_relay_health(min_relays, timeout)) return asyncio.run(self._check_relay_health(min_relays, timeout))
def publish_json_to_nostr( async def publish_json_to_nostr(
self, self,
encrypted_json: bytes, encrypted_json: bytes,
to_pubkey: str | None = None, to_pubkey: str | None = None,
@@ -116,7 +115,7 @@ class ConnectionHandler:
"""Build and publish a Kind 1 text note or direct message.""" """Build and publish a Kind 1 text note or direct message."""
if self.offline_mode or not self.relays: if self.offline_mode or not self.relays:
return None return None
self.connect() await self.connect()
self.last_error = None self.last_error = None
try: try:
content = base64.b64encode(encrypted_json).decode("utf-8") content = base64.b64encode(encrypted_json).decode("utf-8")
@@ -131,7 +130,7 @@ class ConnectionHandler:
if alt_summary: if alt_summary:
builder = builder.tags([nostr_client.Tag.alt(alt_summary)]) builder = builder.tags([nostr_client.Tag.alt(alt_summary)])
event = builder.build(self.keys.public_key()).sign_with_keys(self.keys) event = builder.build(self.keys.public_key()).sign_with_keys(self.keys)
event_output = self.publish_event(event) event_output = await self.publish_event(event)
event_id_hex = ( event_id_hex = (
event_output.id.to_hex() event_output.id.to_hex()
@@ -146,17 +145,11 @@ class ConnectionHandler:
logger.error("Failed to publish JSON to Nostr: %s", e) logger.error("Failed to publish JSON to Nostr: %s", e)
return None return None
def publish_event(self, event): async def publish_event(self, event):
"""Publish a prepared event to the configured relays.""" """Publish a prepared event to the configured relays."""
if self.offline_mode or not self.relays: if self.offline_mode or not self.relays:
return None return None
self.connect() await self.connect()
return asyncio.run(self._publish_event(event))
async def _publish_event(self, event):
if self.offline_mode or not self.relays:
return None
await self._connect_async()
return await self.client.send_event(event) return await self.client.send_event(event)
def update_relays(self, new_relays: List[str]) -> None: def update_relays(self, new_relays: List[str]) -> None:
@@ -168,7 +161,7 @@ class ConnectionHandler:
self._connected = False self._connected = False
self.initialize_client_pool() self.initialize_client_pool()
def retrieve_json_from_nostr_sync( async def retrieve_json_from_nostr(
self, retries: int | None = None, delay: float | None = None self, retries: int | None = None, delay: float | None = None
) -> Optional[bytes]: ) -> Optional[bytes]:
"""Retrieve the latest Kind 1 event from the author with optional retries.""" """Retrieve the latest Kind 1 event from the author with optional retries."""
@@ -190,11 +183,11 @@ class ConnectionHandler:
retries = int(cfg.get("nostr_max_retries", MAX_RETRIES)) retries = int(cfg.get("nostr_max_retries", MAX_RETRIES))
delay = float(cfg.get("nostr_retry_delay", RETRY_DELAY)) delay = float(cfg.get("nostr_retry_delay", RETRY_DELAY))
self.connect() await self.connect()
self.last_error = None self.last_error = None
for attempt in range(retries): for attempt in range(retries):
try: try:
result = asyncio.run(self._retrieve_json_from_nostr()) result = await self._retrieve_json_from_nostr()
if result is not None: if result is not None:
return result return result
except Exception as e: except Exception as e:
@@ -202,7 +195,7 @@ class ConnectionHandler:
logger.error("Failed to retrieve events from Nostr: %s", e) logger.error("Failed to retrieve events from Nostr: %s", e)
if attempt < retries - 1: if attempt < retries - 1:
sleep_time = delay * (2**attempt) sleep_time = delay * (2**attempt)
time.sleep(sleep_time) await asyncio.sleep(sleep_time)
return None return None
async def _retrieve_json_from_nostr(self) -> Optional[bytes]: async def _retrieve_json_from_nostr(self) -> Optional[bytes]:

View File

@@ -4355,7 +4355,7 @@ class PasswordManager:
manifest, event_id = pub_snap(encrypted) manifest, event_id = pub_snap(encrypted)
else: else:
# Fallback for tests using simplified stubs # Fallback for tests using simplified stubs
event_id = self.nostr_client.publish_json_to_nostr(encrypted) event_id = await self.nostr_client.publish_json_to_nostr(encrypted)
self.is_dirty = False self.is_dirty = False
if event_id is None: if event_id is None:
return None return None

View File

@@ -90,7 +90,7 @@ def _setup_client(tmpdir, fake_cls):
def test_initialize_client_pool_add_relays_used(tmp_path): def test_initialize_client_pool_add_relays_used(tmp_path):
client = _setup_client(tmp_path, FakeAddRelaysClient) client = _setup_client(tmp_path, FakeAddRelaysClient)
fc = client.client fc = client.client
client.connect() asyncio.run(client.connect())
assert [[str(r) for r in relays] for relays in fc.added] == [client.relays] assert [[str(r) for r in relays] for relays in fc.added] == [client.relays]
assert fc.connected is True assert fc.connected is True
@@ -98,7 +98,7 @@ def test_initialize_client_pool_add_relays_used(tmp_path):
def test_initialize_client_pool_add_relay_fallback(tmp_path): def test_initialize_client_pool_add_relay_fallback(tmp_path):
client = _setup_client(tmp_path, FakeAddRelayClient) client = _setup_client(tmp_path, FakeAddRelayClient)
fc = client.client fc = client.client
client.connect() asyncio.run(client.connect())
assert [str(r) for r in fc.added] == client.relays assert [str(r) for r in fc.added] == client.relays
assert fc.connected is True assert fc.connected is True
@@ -166,17 +166,42 @@ def test_retrieve_json_sync_backoff(tmp_path, monkeypatch):
sleeps: list[float] = [] sleeps: list[float] = []
def fake_sleep(d): async def fake_sleep(d, *_a, **_k):
sleeps.append(d) sleeps.append(d)
monkeypatch.setattr(nostr_client.time, "sleep", fake_sleep) monkeypatch.setattr(asyncio, "sleep", fake_sleep)
async def fake_async(self): async def fake_async(self):
return None return None
monkeypatch.setattr(NostrClient, "_retrieve_json_from_nostr", fake_async) monkeypatch.setattr(NostrClient, "_retrieve_json_from_nostr", fake_async)
result = client.retrieve_json_from_nostr_sync() result = asyncio.run(client.retrieve_json_from_nostr())
assert result is None assert result is None
assert sleeps == [1, 2] assert sleeps == [1, 2]
def test_client_methods_run_in_event_loop(tmp_path, monkeypatch):
client = _setup_client(tmp_path, FakeAddRelayClient)
async def fake_init(self):
self._connected = True
async def fake_send_event(_event):
return None
async def fake_retrieve(self):
return b"data"
monkeypatch.setattr(NostrClient, "_initialize_client_pool", fake_init)
monkeypatch.setattr(client.client, "send_event", fake_send_event, raising=False)
monkeypatch.setattr(NostrClient, "_retrieve_json_from_nostr", fake_retrieve)
async def run_all():
await client.connect()
await client.publish_event("e")
data = await client.retrieve_json_from_nostr()
assert data == b"data"
asyncio.run(run_all())