Merge pull request #70 from PR0M3TH3AN/codex/switch-to-nostr-sdk-and-update-tests

Switch to nostr-sdk
This commit is contained in:
thePR0M3TH3AN
2025-06-30 22:50:21 -04:00
committed by GitHub
6 changed files with 111 additions and 92 deletions

View File

@@ -3,14 +3,21 @@ import base64
import hashlib import hashlib
import json import json
import logging import logging
import time
import uuid
from pathlib import Path from pathlib import Path
from typing import Callable, List, Optional from typing import Callable, List, Optional
from pynostr.websocket_relay_manager import WebSocketRelayManager from nostr_sdk import nostr_sdk as sdk
from pynostr.event import Event, EventKind from nostr_sdk import uniffi_set_event_loop
from pynostr.encrypted_dm import EncryptedDirectMessage
# expose key SDK classes for easier mocking in tests
ClientBuilder = sdk.ClientBuilder
EventBuilder = sdk.EventBuilder
Kind = sdk.Kind
KindStandard = sdk.KindStandard
Filter = sdk.Filter
Keys = sdk.Keys
PublicKey = sdk.PublicKey
Duration = sdk.Duration
from .key_manager import KeyManager from .key_manager import KeyManager
from password_manager.encryption import EncryptionManager from password_manager.encryption import EncryptionManager
@@ -28,7 +35,7 @@ DEFAULT_RELAYS = [
class NostrClient: class NostrClient:
"""Interact with the Nostr network using pynostr.""" """Interact with the Nostr network using nostr-sdk."""
def __init__( def __init__(
self, self,
@@ -49,47 +56,58 @@ class NostrClient:
self.initialize_client_pool() self.initialize_client_pool()
def initialize_client_pool(self) -> None: def initialize_client_pool(self) -> None:
"""Create the relay manager and connect to configured relays.""" """Create the client and connect to configured relays."""
self.client_pool = WebSocketRelayManager()
for relay in self.relays:
self.client_pool.add_relay(relay)
async def publish_event_async(self, event: Event) -> None: async def _init() -> None:
logger.debug("Publishing event %s", event.id) uniffi_set_event_loop(asyncio.get_running_loop())
self.client_pool.publish_event(event) self.client_pool = ClientBuilder().build()
for relay in self.relays:
await self.client_pool.add_relay(relay)
await self.client_pool.connect()
def publish_event(self, event: Event) -> None: asyncio.run(_init())
self.client_pool.publish_event(event)
async def publish_event_async(self, event) -> None:
logger.debug("Publishing event %s", event.id())
uniffi_set_event_loop(asyncio.get_running_loop())
await self.client_pool.send_event(event)
def publish_event(self, event) -> None:
asyncio.run(self.publish_event_async(event))
async def subscribe_async( async def subscribe_async(
self, self,
filters: List[dict], filters: List[dict],
handler: Callable[[WebSocketRelayManager, str, Event], None], handler: Callable[[object, str, object], None],
timeout: float = 2.0, timeout: float = 2.0,
) -> None: ) -> None:
sub_id = str(uuid.uuid4()) uniffi_set_event_loop(asyncio.get_running_loop())
from pynostr.filters import FiltersList for f in filters:
flt = Filter()
if "authors" in f:
flt = flt.authors([PublicKey.parse(a) for a in f["authors"]])
if "kinds" in f:
kinds = []
for k in f["kinds"]:
if k == 1:
kinds.append(sdk.Kind.from_std(sdk.KindStandard.TEXT_NOTE))
elif k == 4:
kinds.append(
sdk.Kind.from_std(sdk.KindStandard.PRIVATE_DIRECT_MESSAGE)
)
if kinds:
flt = flt.kinds(kinds)
if "limit" in f:
flt = flt.limit(f["limit"])
filter_list = FiltersList.from_json_array(filters) events = await self.client_pool.fetch_events(flt, Duration(seconds=timeout))
self.client_pool.add_subscription_on_all_relays(sub_id, filter_list) for evt in events:
self.subscriptions.add(sub_id) handler(self.client_pool, "0", evt)
end = asyncio.get_event_loop().time() + timeout
try:
while asyncio.get_event_loop().time() < end:
while self.client_pool.message_pool.has_events():
msg = self.client_pool.message_pool.get_event()
if msg.subscription_id == sub_id:
handler(self.client_pool, sub_id, msg.event)
await asyncio.sleep(0.1)
finally:
self.client_pool.close_subscription_on_all_relays(sub_id)
self.subscriptions.discard(sub_id)
def subscribe( def subscribe(
self, self,
filters: List[dict], filters: List[dict],
handler: Callable[[WebSocketRelayManager, str, Event], None], handler: Callable[[object, str, object], None],
timeout: float = 2.0, timeout: float = 2.0,
) -> None: ) -> None:
asyncio.run(self.subscribe_async(filters, handler, timeout)) asyncio.run(self.subscribe_async(filters, handler, timeout))
@@ -98,13 +116,13 @@ class NostrClient:
filters = [ filters = [
{ {
"authors": [self.key_manager.keys.public_key_hex()], "authors": [self.key_manager.keys.public_key_hex()],
"kinds": [EventKind.TEXT_NOTE, EventKind.ENCRYPTED_DIRECT_MESSAGE], "kinds": [1, 4],
"limit": 1, "limit": 1,
} }
] ]
events: list[Event] = [] events: list = []
async def handler(_client, _sid, evt: Event): async def handler(_client, _sid, evt):
events.append(evt) events.append(evt)
await self.subscribe_async(filters, handler) await self.subscribe_async(filters, handler)
@@ -113,32 +131,26 @@ class NostrClient:
return None return None
event = events[0] event = events[0]
content_base64 = event.content content_base64 = event.content()
if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE:
dm = EncryptedDirectMessage.from_event(event)
dm.decrypt(
self.key_manager.keys.private_key_hex(), public_key_hex=dm.pubkey
)
content_base64 = dm.cleartext_content
return content_base64 return content_base64
def retrieve_json_from_nostr(self) -> Optional[str]: def retrieve_json_from_nostr(self) -> Optional[str]:
return asyncio.run(self.retrieve_json_from_nostr_async()) return asyncio.run(self.retrieve_json_from_nostr_async())
async def do_post_async(self, text: str) -> None: async def do_post_async(self, text: str) -> None:
event = Event(kind=EventKind.TEXT_NOTE, content=text) keys = Keys.parse(self.key_manager.keys.private_key_hex())
event.pubkey = self.key_manager.keys.public_key_hex() event = (
event.created_at = int(time.time()) EventBuilder.text_note(text).build(keys.public_key()).sign_with_keys(keys)
event.sign(self.key_manager.keys.private_key_hex()) )
await self.publish_event_async(event) await self.publish_event_async(event)
async def subscribe_feed_async( async def subscribe_feed_async(
self, handler: Callable[[WebSocketRelayManager, str, Event], None] self, handler: Callable[[object, str, object], None]
) -> None: ) -> None:
filters = [ filters = [
{ {
"authors": [self.key_manager.keys.public_key_hex()], "authors": [self.key_manager.keys.public_key_hex()],
"kinds": [EventKind.TEXT_NOTE, EventKind.ENCRYPTED_DIRECT_MESSAGE], "kinds": [1, 4],
"limit": 100, "limit": 100,
} }
] ]
@@ -190,19 +202,12 @@ class NostrClient:
) -> bool: ) -> bool:
try: try:
content = base64.b64encode(encrypted_json).decode("utf-8") content = base64.b64encode(encrypted_json).decode("utf-8")
if to_pubkey: keys = Keys.parse(self.key_manager.keys.private_key_hex())
dm = EncryptedDirectMessage() event = (
dm.encrypt( EventBuilder.text_note(content)
private_key_hex=self.key_manager.keys.private_key_hex(), .build(keys.public_key())
cleartext_content=content, .sign_with_keys(keys)
recipient_pubkey=to_pubkey, )
)
event = dm.to_event()
else:
event = Event(kind=EventKind.TEXT_NOTE, content=content)
event.pubkey = self.key_manager.keys.public_key_hex()
event.created_at = int(time.time())
event.sign(self.key_manager.keys.private_key_hex())
self.publish_event(event) self.publish_event(event)
return True return True
except Exception as e: # pragma: no cover - defensive except Exception as e: # pragma: no cover - defensive
@@ -219,13 +224,14 @@ class NostrClient:
self.decrypt_and_save_index_from_nostr(encrypted_data) self.decrypt_and_save_index_from_nostr(encrypted_data)
async def close_client_pool_async(self) -> None: async def close_client_pool_async(self) -> None:
self.client_pool.close_all_relay_connections() uniffi_set_event_loop(asyncio.get_running_loop())
await self.client_pool.disconnect()
def close_client_pool(self) -> None: def close_client_pool(self) -> None:
self.client_pool.close_all_relay_connections() asyncio.run(self.close_client_pool_async())
async def safe_close_connection(self, client): # pragma: no cover - compatibility async def safe_close_connection(self, client): # pragma: no cover - compatibility
try: try:
await client.close_connection() await client.disconnect()
except Exception: except Exception:
pass pass

View File

@@ -11,7 +11,7 @@ bip85
pytest>=7.0 pytest>=7.0
pytest-cov pytest-cov
portalocker>=2.8 portalocker>=2.8
pynostr>=0.6.2 nostr-sdk>=0.42.1
websocket-client==1.7.0 websocket-client==1.7.0
websockets>=15.0.0 websockets>=15.0.0

View File

@@ -16,11 +16,11 @@ def test_nostr_client_uses_custom_relays():
enc_mgr = EncryptionManager(key, Path(tmpdir)) enc_mgr = EncryptionManager(key, Path(tmpdir))
custom_relays = ["wss://relay1", "wss://relay2"] custom_relays = ["wss://relay1", "wss://relay2"]
with patch("nostr.client.WebSocketRelayManager") as MockPool, patch( with patch("nostr.client.ClientBuilder") as MockBuilder, patch(
"nostr.client.KeyManager" "nostr.client.KeyManager"
): ), patch.object(NostrClient, "initialize_client_pool"):
mock_builder = MockBuilder.return_value
with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"): with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"):
client = NostrClient(enc_mgr, "fp", relays=custom_relays) client = NostrClient(enc_mgr, "fp", relays=custom_relays)
MockPool.assert_called_with()
assert client.relays == custom_relays assert client.relays == custom_relays

View File

@@ -4,9 +4,10 @@ import threading
import time import time
from websocket import create_connection from websocket import create_connection
import asyncio
import websockets import websockets
from nostr.key_manager import KeyManager from nostr.key_manager import KeyManager
from pynostr.event import Event, EventKind from nostr_sdk import nostr_sdk as sdk
class FakeRelay: class FakeRelay:
@@ -35,7 +36,7 @@ def run_relay(relay, host="localhost", port=8765):
asyncio.run(main()) asyncio.run(main())
def test_pynostr_send_receive(tmp_path): def test_nostr_sdk_send_receive(tmp_path):
relay = FakeRelay() relay = FakeRelay()
thread = threading.Thread(target=run_relay, args=(relay,), daemon=True) thread = threading.Thread(target=run_relay, args=(relay,), daemon=True)
thread.start() thread.start()
@@ -48,12 +49,13 @@ def test_pynostr_send_receive(tmp_path):
ws = create_connection("ws://localhost:8765") ws = create_connection("ws://localhost:8765")
event = Event(kind=EventKind.TEXT_NOTE, content="hello") keys = sdk.Keys.parse(km.get_private_key_hex())
event.pubkey = km.get_public_key_hex() event = (
event.created_at = int(time.time()) sdk.EventBuilder.text_note("hello")
event.sign(km.get_private_key_hex()) .build(keys.public_key())
.sign_with_keys(keys)
ws.send(event.to_message()) )
ws.send(json.dumps(["EVENT", json.loads(event.as_json())]))
sub_id = "1" sub_id = "1"
ws.send(json.dumps(["REQ", sub_id, {}])) ws.send(json.dumps(["REQ", sub_id, {}]))

View File

@@ -14,31 +14,40 @@ def setup_client(tmp_path):
key = Fernet.generate_key() key = Fernet.generate_key()
enc_mgr = EncryptionManager(key, tmp_path) enc_mgr = EncryptionManager(key, tmp_path)
with patch("nostr.client.WebSocketRelayManager"), patch( with patch("nostr.client.ClientBuilder"), patch(
"nostr.client.KeyManager" "nostr.client.KeyManager"
), patch.object(NostrClient, "initialize_client_pool"), patch.object( ) as MockKM, patch.object(NostrClient, "initialize_client_pool"), patch.object(
enc_mgr, "decrypt_parent_seed", return_value="seed" enc_mgr, "decrypt_parent_seed", return_value="seed"
): ):
km_inst = MockKM.return_value
km_inst.keys.private_key_hex.return_value = "1" * 64
km_inst.keys.public_key_hex.return_value = "2" * 64
client = NostrClient(enc_mgr, "fp") client = NostrClient(enc_mgr, "fp")
return client return client
class FakeEvent: class FakeEvent:
KIND_TEXT_NOTE = 1 def __init__(self):
KIND_ENCRYPT = 2 self._id = "id"
def __init__(self, kind, content, pub_key=None): def id(self):
self.kind = kind return self._id
self.content = content
self.pubkey = pub_key
self.id = "id"
def sign(self, _):
pass class FakeUnsignedEvent:
def sign_with_keys(self, _):
return FakeEvent()
class FakeBuilder:
def build(self, _):
return FakeUnsignedEvent()
def test_publish_json_success(): def test_publish_json_success():
with TemporaryDirectory() as tmpdir, patch("nostr.client.Event", FakeEvent): with TemporaryDirectory() as tmpdir, patch(
"nostr.client.EventBuilder.text_note", return_value=FakeBuilder()
):
client = setup_client(Path(tmpdir)) client = setup_client(Path(tmpdir))
with patch.object(client, "publish_event") as mock_pub: with patch.object(client, "publish_event") as mock_pub:
assert client.publish_json_to_nostr(b"data") is True assert client.publish_json_to_nostr(b"data") is True
@@ -46,7 +55,9 @@ def test_publish_json_success():
def test_publish_json_failure(): def test_publish_json_failure():
with TemporaryDirectory() as tmpdir, patch("nostr.client.Event", FakeEvent): with TemporaryDirectory() as tmpdir, patch(
"nostr.client.EventBuilder.text_note", return_value=FakeBuilder()
):
client = setup_client(Path(tmpdir)) client = setup_client(Path(tmpdir))
with patch.object(client, "publish_event", side_effect=Exception("boom")): with patch.object(client, "publish_event", side_effect=Exception("boom")):
assert client.publish_json_to_nostr(b"data") is False assert client.publish_json_to_nostr(b"data") is False

View File

@@ -27,7 +27,7 @@ def test_backup_and_publish_to_nostr():
with patch( with patch(
"nostr.client.NostrClient.publish_json_to_nostr", return_value=True "nostr.client.NostrClient.publish_json_to_nostr", return_value=True
) as mock_publish, patch("nostr.client.WebSocketRelayManager"), patch( ) as mock_publish, patch("nostr.client.ClientBuilder"), patch(
"nostr.client.KeyManager" "nostr.client.KeyManager"
), patch.object( ), patch.object(
NostrClient, "initialize_client_pool" NostrClient, "initialize_client_pool"