Make Nostr client async

This commit is contained in:
thePR0M3TH3AN
2025-08-10 20:59:04 -04:00
parent b9525db9ae
commit b0a2f17cc8
3 changed files with 42 additions and 24 deletions

View File

@@ -2,7 +2,6 @@ import asyncio
import base64
import json
import logging
import time
from datetime import timedelta
from typing import List, Optional
@@ -23,12 +22,12 @@ DEFAULT_RELAYS = [
class ConnectionHandler:
"""Mixin providing relay connection and retry logic."""
def connect(self) -> None:
async def connect(self) -> None:
"""Connect the client to all configured relays."""
if self.offline_mode or not self.relays:
return
if not getattr(self, "_connected", False):
self.initialize_client_pool()
await self._initialize_client_pool()
def initialize_client_pool(self) -> None:
"""Add relays to the client and connect."""
@@ -107,7 +106,7 @@ class ConnectionHandler:
return 0
return asyncio.run(self._check_relay_health(min_relays, timeout))
def publish_json_to_nostr(
async def publish_json_to_nostr(
self,
encrypted_json: bytes,
to_pubkey: str | None = None,
@@ -116,7 +115,7 @@ class ConnectionHandler:
"""Build and publish a Kind 1 text note or direct message."""
if self.offline_mode or not self.relays:
return None
self.connect()
await self.connect()
self.last_error = None
try:
content = base64.b64encode(encrypted_json).decode("utf-8")
@@ -131,7 +130,7 @@ class ConnectionHandler:
if alt_summary:
builder = builder.tags([nostr_client.Tag.alt(alt_summary)])
event = builder.build(self.keys.public_key()).sign_with_keys(self.keys)
event_output = self.publish_event(event)
event_output = await self.publish_event(event)
event_id_hex = (
event_output.id.to_hex()
@@ -146,17 +145,11 @@ class ConnectionHandler:
logger.error("Failed to publish JSON to Nostr: %s", e)
return None
def publish_event(self, event):
async def publish_event(self, event):
"""Publish a prepared event to the configured relays."""
if self.offline_mode or not self.relays:
return None
self.connect()
return asyncio.run(self._publish_event(event))
async def _publish_event(self, event):
if self.offline_mode or not self.relays:
return None
await self._connect_async()
await self.connect()
return await self.client.send_event(event)
def update_relays(self, new_relays: List[str]) -> None:
@@ -168,7 +161,7 @@ class ConnectionHandler:
self._connected = False
self.initialize_client_pool()
def retrieve_json_from_nostr_sync(
async def retrieve_json_from_nostr(
self, retries: int | None = None, delay: float | None = None
) -> Optional[bytes]:
"""Retrieve the latest Kind 1 event from the author with optional retries."""
@@ -190,11 +183,11 @@ class ConnectionHandler:
retries = int(cfg.get("nostr_max_retries", MAX_RETRIES))
delay = float(cfg.get("nostr_retry_delay", RETRY_DELAY))
self.connect()
await self.connect()
self.last_error = None
for attempt in range(retries):
try:
result = asyncio.run(self._retrieve_json_from_nostr())
result = await self._retrieve_json_from_nostr()
if result is not None:
return result
except Exception as e:
@@ -202,7 +195,7 @@ class ConnectionHandler:
logger.error("Failed to retrieve events from Nostr: %s", e)
if attempt < retries - 1:
sleep_time = delay * (2**attempt)
time.sleep(sleep_time)
await asyncio.sleep(sleep_time)
return None
async def _retrieve_json_from_nostr(self) -> Optional[bytes]:

View File

@@ -4355,7 +4355,7 @@ class PasswordManager:
manifest, event_id = pub_snap(encrypted)
else:
# Fallback for tests using simplified stubs
event_id = self.nostr_client.publish_json_to_nostr(encrypted)
event_id = await self.nostr_client.publish_json_to_nostr(encrypted)
self.is_dirty = False
if event_id is None:
return None

View File

@@ -90,7 +90,7 @@ def _setup_client(tmpdir, fake_cls):
def test_initialize_client_pool_add_relays_used(tmp_path):
client = _setup_client(tmp_path, FakeAddRelaysClient)
fc = client.client
client.connect()
asyncio.run(client.connect())
assert [[str(r) for r in relays] for relays in fc.added] == [client.relays]
assert fc.connected is True
@@ -98,7 +98,7 @@ def test_initialize_client_pool_add_relays_used(tmp_path):
def test_initialize_client_pool_add_relay_fallback(tmp_path):
client = _setup_client(tmp_path, FakeAddRelayClient)
fc = client.client
client.connect()
asyncio.run(client.connect())
assert [str(r) for r in fc.added] == client.relays
assert fc.connected is True
@@ -166,17 +166,42 @@ def test_retrieve_json_sync_backoff(tmp_path, monkeypatch):
sleeps: list[float] = []
def fake_sleep(d):
async def fake_sleep(d, *_a, **_k):
sleeps.append(d)
monkeypatch.setattr(nostr_client.time, "sleep", fake_sleep)
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
async def fake_async(self):
return None
monkeypatch.setattr(NostrClient, "_retrieve_json_from_nostr", fake_async)
result = client.retrieve_json_from_nostr_sync()
result = asyncio.run(client.retrieve_json_from_nostr())
assert result is None
assert sleeps == [1, 2]
def test_client_methods_run_in_event_loop(tmp_path, monkeypatch):
client = _setup_client(tmp_path, FakeAddRelayClient)
async def fake_init(self):
self._connected = True
async def fake_send_event(_event):
return None
async def fake_retrieve(self):
return b"data"
monkeypatch.setattr(NostrClient, "_initialize_client_pool", fake_init)
monkeypatch.setattr(client.client, "send_event", fake_send_event, raising=False)
monkeypatch.setattr(NostrClient, "_retrieve_json_from_nostr", fake_retrieve)
async def run_all():
await client.connect()
await client.publish_event("e")
data = await client.retrieve_json_from_nostr()
assert data == b"data"
asyncio.run(run_all())