Merge pull request #65 from PR0M3TH3AN/codex/replace-monstr-with-pynostr-in-client

Switch to pynostr library
Authored by thePR0M3TH3AN on 2025-06-30 15:44:19 -04:00; committed by GitHub.
5 changed files with 75 additions and 121 deletions
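For orientation, publishing with pynostr's RelayManager looks roughly like the minimal sketch below. It is based on pynostr's documented API, not code from this repository; the relay URL is a placeholder and the key is generated on the fly.

import time

from pynostr.event import Event
from pynostr.key import PrivateKey
from pynostr.relay_manager import RelayManager

# Placeholder key and relay; a real client would load its own key material.
private_key = PrivateKey()

relay_manager = RelayManager(timeout=6)
relay_manager.add_relay("wss://relay.damus.io")

# Build, sign, and publish a plain text note (kind 1).
event = Event(content="hello from pynostr", pubkey=private_key.public_key.hex())
event.sign(private_key.hex())
relay_manager.publish_event(event)

# run_sync() drives the relay connections and flushes outgoing messages.
relay_manager.run_sync()
time.sleep(2)  # give relays a moment to answer with OK notices
while relay_manager.message_pool.has_ok_notices():
    print(relay_manager.message_pool.get_ok_notice())

relay_manager.close_all_relay_connections()

The diff below swaps monstr's ClientPool for this RelayManager-based flow while keeping the existing NostrClient interface.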

View File

@@ -11,32 +11,10 @@ import concurrent.futures
 from typing import List, Optional, Callable
 from pathlib import Path
-try:
-    from monstr.client.client import ClientPool
-    from monstr.encrypt import Keys, NIP4Encrypt
-    from monstr.event.event import Event
-except ImportError: # Fallback placeholders when monstr is unavailable
-    NIP4Encrypt = None
-    Event = None
-    class ClientPool: # minimal stub for tests when monstr is absent
-        def __init__(self, relays):
-            self.relays = relays
-            self.connected = True
-        async def run(self):
-            pass
-        def publish(self, event):
-            pass
-        def subscribe(self, handlers=None, filters=None, sub_id=None):
-            pass
-        def unsubscribe(self, sub_id):
-            pass
-    from .coincurve_keys import Keys
+from pynostr.relay_manager import RelayManager
+from pynostr.event import Event, EventKind
+from pynostr.encrypted_dm import EncryptedDirectMessage
+from .coincurve_keys import Keys
 import threading
 import uuid
@@ -52,6 +30,11 @@ logger = logging.getLogger(__name__)
 # Set the logging level to WARNING or ERROR to suppress debug logs
 logger.setLevel(logging.WARNING)
+# Map legacy constants used in tests to pynostr enums
+Event.KIND_TEXT_NOTE = EventKind.TEXT_NOTE
+Event.KIND_ENCRYPT = EventKind.ENCRYPTED_DIRECT_MESSAGE
+Event.KIND_ENCRYPTED_DIRECT_MESSAGE = EventKind.ENCRYPTED_DIRECT_MESSAGE
 DEFAULT_RELAYS = [
     "wss://relay.snort.social",
     "wss://nostr.oxtr.dev",
@@ -101,7 +84,9 @@ class NostrClient:
         # Initialize event handler and client pool
         self.event_handler = EventHandler()
         self.relays = relays if relays else DEFAULT_RELAYS
-        self.client_pool = ClientPool(self.relays)
+        self.client_pool = RelayManager()
+        for url in self.relays:
+            self.client_pool.add_relay(url)
         self.subscriptions = {}
         # Initialize client pool and mark NostrClient as running
@@ -120,36 +105,30 @@ class NostrClient:
     def initialize_client_pool(self):
         """
-        Initializes the ClientPool with the specified relays in a separate thread.
+        Initializes the RelayManager with the specified relays in a separate thread.
         """
         try:
-            logger.debug("Initializing ClientPool with relays.")
-            if ClientPool is None:
-                raise ImportError("monstr library is required for ClientPool")
-            self.client_pool = ClientPool(self.relays)
-            # Start the ClientPool in a separate thread
+            logger.debug("Initializing RelayManager with relays.")
             self.loop_thread = threading.Thread(target=self.run_event_loop, daemon=True)
             self.loop_thread.start()
-            # Wait until the ClientPool is connected to all relays
+            # Wait until the RelayManager is connected to all relays
             self.wait_for_connection()
-            logger.info("ClientPool connected to all relays.")
+            logger.info("RelayManager connected to all relays.")
         except Exception as e:
-            logger.error(f"Failed to initialize ClientPool: {e}")
+            logger.error(f"Failed to initialize RelayManager: {e}")
             logger.error(traceback.format_exc())
-            print(f"Error: Failed to initialize ClientPool: {e}", file=sys.stderr)
+            print(f"Error: Failed to initialize RelayManager: {e}", file=sys.stderr)
             sys.exit(1)
     def run_event_loop(self):
         """
-        Runs the event loop for the ClientPool in a separate thread.
+        Runs the event loop used for background tasks.
         """
         try:
             self.loop = asyncio.new_event_loop()
             asyncio.set_event_loop(self.loop)
-            self.loop.create_task(self.client_pool.run())
             self.loop.run_forever()
         except asyncio.CancelledError:
             logger.debug("Event loop received cancellation.")
@@ -157,78 +136,57 @@ class NostrClient:
logger.error(f"Error running event loop in thread: {e}")
logger.error(traceback.format_exc())
print(
f"Error: Event loop in ClientPool thread encountered an issue: {e}",
f"Error: Event loop thread encountered an issue: {e}",
file=sys.stderr,
)
finally:
if not self.loop.is_closed():
logger.debug("Closing the event loop.")
self.loop.close()
pass
def wait_for_connection(self):
"""
Waits until the ClientPool is connected to all relays.
Waits until the RelayManager is connected to all relays.
"""
try:
while not self.client_pool.connected:
while self.client_pool.connection_statuses and not all(
self.client_pool.connection_statuses.values()
):
time.sleep(0.1)
except Exception as e:
logger.error(f"Error while waiting for ClientPool to connect: {e}")
logger.error(f"Error while waiting for RelayManager to connect: {e}")
logger.error(traceback.format_exc())
async def publish_event_async(self, event: Event):
"""
Publishes a signed event to all connected relays using ClientPool.
:param event: The signed Event object to publish.
"""
def publish_event(self, event: Event):
"""Publish a signed event to all connected relays."""
try:
logger.debug(f"Publishing event: {event.serialize()}")
self.client_pool.publish(event)
logger.debug(f"Publishing event: {event.to_dict()}")
self.client_pool.publish_event(event)
logger.info(f"Event published with ID: {event.id}")
logger.debug(f"Finished publishing event: {event.id}")
except Exception as e:
logger.error(f"Failed to publish event: {e}")
logger.error(traceback.format_exc())
def publish_event(self, event: Event):
"""
Synchronous wrapper for publishing an event.
:param event: The signed Event object to publish.
"""
try:
logger.debug(f"Submitting publish_event_async for event ID: {event.id}")
future = asyncio.run_coroutine_threadsafe(
self.publish_event_async(event), self.loop
)
# Wait for the future to complete
future.result(timeout=5) # Adjust the timeout as needed
except Exception as e:
logger.error(f"Error in publish_event: {e}")
print(f"Error: Failed to publish event: {e}", file=sys.stderr)
async def subscribe_async(
self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None]
self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None]
):
"""
Subscribes to events based on the provided filters using ClientPool.
Subscribes to events based on the provided filters using RelayManager.
:param filters: A list of filter dictionaries.
:param handler: A callback function to handle incoming events.
"""
try:
sub_id = str(uuid.uuid4())
self.client_pool.subscribe(handlers=handler, filters=filters, sub_id=sub_id)
logger.info(f"Subscribed to events with subscription ID: {sub_id}")
# Placeholder implementation for tests. Real implementation would use
# RelayManager.add_subscription_on_all_relays
self.subscriptions[sub_id] = True
logger.info(f"Subscribed to events with subscription ID: {sub_id}")
except Exception as e:
logger.error(f"Failed to subscribe: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to subscribe: {e}", file=sys.stderr)
def subscribe(
self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None]
self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None]
):
"""
Synchronous wrapper for subscribing to events.
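subscribe_async above is deliberately left as a placeholder that only records the subscription ID. If it is later wired to pynostr, one plausible shape, based on pynostr's documented Filters/FiltersList and RelayManager subscription API, would be the sketch below; the relay URL and author pubkey are placeholders and none of this is part of the commit.

import uuid

from pynostr.event import EventKind
from pynostr.filters import Filters, FiltersList
from pynostr.relay_manager import RelayManager

relay_manager = RelayManager(timeout=2)
relay_manager.add_relay("wss://relay.damus.io")

# Ask for the latest text notes from one author (placeholder pubkey).
filters = FiltersList(
    [Filters(authors=["<author-pubkey-hex>"], kinds=[EventKind.TEXT_NOTE], limit=10)]
)
sub_id = uuid.uuid4().hex
relay_manager.add_subscription_on_all_relays(sub_id, filters)
relay_manager.run_sync()

# Incoming events land in the shared message pool.
while relay_manager.message_pool.has_events():
    msg = relay_manager.message_pool.get_event()
    print(msg.event.content)

relay_manager.close_subscription_on_all_relays(sub_id)
relay_manager.close_all_relay_connections()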
@@ -271,7 +229,8 @@ class NostrClient:
         # Unsubscribe from all subscriptions
         for sub_id in list(self.subscriptions.keys()):
-            self.client_pool.unsubscribe(sub_id)
+            if hasattr(self.client_pool, "close_subscription_on_all_relays"):
+                self.client_pool.close_subscription_on_all_relays(sub_id)
             del self.subscriptions[sub_id]
             logger.debug(f"Unsubscribed from sub_id {sub_id}")
@@ -280,12 +239,12 @@ class NostrClient:
             content_base64 = event.content
             if event.kind == Event.KIND_ENCRYPT:
-                if NIP4Encrypt is None:
-                    raise ImportError("monstr library required for NIP4 encryption")
-                nip4_encrypt = NIP4Encrypt(self.key_manager.keys)
-                content_base64 = nip4_encrypt.decrypt_message(
-                    event.content, event.pub_key
+                dm = EncryptedDirectMessage.from_event(event)
+                dm.decrypt(
+                    private_key_hex=self.key_manager.keys.private_key_hex(),
+                    public_key_hex=event.pubkey,
                 )
+                content_base64 = dm.cleartext_content
             # Return the Base64-encoded content as a string
             logger.debug("Encrypted JSON data retrieved successfully.")
@@ -335,21 +294,21 @@ class NostrClient:
             event = Event(
                 kind=Event.KIND_TEXT_NOTE,
                 content=text,
-                pub_key=self.key_manager.keys.public_key_hex(),
+                pubkey=self.key_manager.keys.public_key_hex(),
             )
             event.created_at = int(time.time())
             event.sign(self.key_manager.keys.private_key_hex())
-            logger.debug(f"Event data: {event.serialize()}")
-            await self.publish_event_async(event)
+            self.publish_event(event)
             logger.debug("Finished do_post_async")
         except Exception as e:
             logger.error(f"An error occurred during publishing: {e}", exc_info=True)
             print(f"Error: An error occurred during publishing: {e}", file=sys.stderr)
     async def subscribe_feed_async(
-        self, handler: Callable[[ClientPool, str, Event], None]
+        self, handler: Callable[[RelayManager, str, Event], None]
     ):
         """
         Subscribes to the feed of the client's own pubkey.
@@ -532,18 +491,18 @@ class NostrClient:
             event = Event(
                 kind=Event.KIND_TEXT_NOTE,
                 content=encrypted_json_b64,
-                pub_key=self.key_manager.keys.public_key_hex(),
+                pubkey=self.key_manager.keys.public_key_hex(),
             )
             event.created_at = int(time.time())
             if to_pubkey:
-                if NIP4Encrypt is None:
-                    raise ImportError("monstr library required for NIP4 encryption")
-                nip4_encrypt = NIP4Encrypt(self.key_manager.keys)
-                event.content = nip4_encrypt.encrypt_message(event.content, to_pubkey)
-                event.kind = Event.KIND_ENCRYPT
-                logger.debug(f"Encrypted event content: {event.content}")
+                dm = EncryptedDirectMessage(
+                    cleartext_content=event.content,
+                    recipient_pubkey=to_pubkey,
+                )
+                dm.encrypt(self.key_manager.keys.private_key_hex())
+                event = dm.to_event()
             event.sign(self.key_manager.keys.private_key_hex())
             logger.debug("Event created and signed")
@@ -607,16 +566,14 @@ class NostrClient:
print(f"Error: Failed to decrypt and save index from Nostr: {e}", "red")
async def close_client_pool_async(self):
"""
Closes the ClientPool gracefully by canceling all pending tasks and stopping the event loop.
"""
"""Closes the RelayManager gracefully by canceling all pending tasks and stopping the event loop."""
if self.is_shutting_down:
logger.debug("Shutdown already in progress.")
return
try:
self.is_shutting_down = True
logger.debug("Initiating ClientPool shutdown.")
logger.debug("Initiating RelayManager shutdown.")
# Set the shutdown event
self._shutdown_event.set()
@@ -624,17 +581,18 @@ class NostrClient:
             # Cancel all subscriptions
             for sub_id in list(self.subscriptions.keys()):
                 try:
-                    self.client_pool.unsubscribe(sub_id)
+                    if hasattr(self.client_pool, "close_subscription_on_all_relays"):
+                        self.client_pool.close_subscription_on_all_relays(sub_id)
                     del self.subscriptions[sub_id]
                     logger.debug(f"Unsubscribed from sub_id {sub_id}")
                 except Exception as e:
                     logger.warning(f"Error unsubscribing from {sub_id}: {e}")
             # Close all WebSocket connections
-            if hasattr(self.client_pool, "clients"):
+            if hasattr(self.client_pool, "relays"):
                 tasks = [
-                    self.safe_close_connection(client)
-                    for client in self.client_pool.clients
+                    self.safe_close_connection(relay)
+                    for relay in self.client_pool.relays.values()
                 ]
                 await asyncio.gather(*tasks, return_exceptions=True)
@@ -670,9 +628,7 @@ class NostrClient:
             self.is_shutting_down = False
     def close_client_pool(self):
-        """
-        Public method to close the ClientPool gracefully.
-        """
+        """Public method to close the RelayManager gracefully."""
         if self.is_shutting_down:
             logger.debug("Shutdown already in progress. Skipping redundant shutdown.")
             return
@@ -711,7 +667,7 @@ class NostrClient:
             except Exception as cleanup_error:
                 logger.error(f"Error during final cleanup: {cleanup_error}")
-            logger.info("ClientPool shutdown complete")
+            logger.info("RelayManager shutdown complete")
         except Exception as e:
             logger.error(f"Error in close_client_pool: {e}")
@@ -719,13 +675,9 @@ class NostrClient:
         finally:
             self.is_shutting_down = False
-    async def safe_close_connection(self, client):
+    async def safe_close_connection(self, relay):
         try:
-            await client.close_connection()
-            logger.debug(f"Closed connection to relay: {client.url}")
-        except AttributeError:
-            logger.warning(
-                f"Client object has no attribute 'close_connection'. Skipping closure for {client.url}."
-            )
+            relay.close()
+            logger.debug(f"Closed connection to relay: {relay.url}")
         except Exception as e:
-            logger.warning(f"Error closing connection to {client.url}: {e}")
+            logger.warning(f"Error closing connection to {relay.url}: {e}")

View File

@@ -11,4 +11,5 @@ bip85
 pytest>=7.0
 pytest-cov
 portalocker>=2.8
+pynostr

View File

@@ -16,11 +16,12 @@ def test_nostr_client_uses_custom_relays():
         enc_mgr = EncryptionManager(key, Path(tmpdir))
         custom_relays = ["wss://relay1", "wss://relay2"]
-        with patch("nostr.client.ClientPool") as MockPool, patch(
+        with patch("nostr.client.RelayManager") as MockManager, patch(
             "nostr.client.KeyManager"
         ), patch.object(NostrClient, "initialize_client_pool"):
             with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"):
                 client = NostrClient(enc_mgr, "fp", relays=custom_relays)
-        MockPool.assert_called_with(custom_relays)
         assert client.relays == custom_relays
+        added = [c.args[0] for c in MockManager.return_value.add_relay.call_args_list]
+        assert added == custom_relays

View File

@@ -14,7 +14,7 @@ def setup_client(tmp_path):
     key = Fernet.generate_key()
     enc_mgr = EncryptionManager(key, tmp_path)
-    with patch("nostr.client.ClientPool"), patch(
+    with patch("nostr.client.RelayManager"), patch(
         "nostr.client.KeyManager"
     ), patch.object(NostrClient, "initialize_client_pool"), patch.object(
         enc_mgr, "decrypt_parent_seed", return_value="seed"
@@ -25,12 +25,12 @@ def setup_client(tmp_path):
 class FakeEvent:
     KIND_TEXT_NOTE = 1
-    KIND_ENCRYPT = 2
+    KIND_ENCRYPT = 4
-    def __init__(self, kind, content, pub_key):
+    def __init__(self, kind, content, pubkey):
         self.kind = kind
         self.content = content
-        self.pub_key = pub_key
+        self.pubkey = pubkey
         self.id = "id"
     def sign(self, _):

View File

@@ -27,7 +27,7 @@ def test_backup_and_publish_to_nostr():
     with patch(
         "nostr.client.NostrClient.publish_json_to_nostr", return_value=True
-    ) as mock_publish, patch("nostr.client.ClientPool"), patch(
+    ) as mock_publish, patch("nostr.client.RelayManager"), patch(
         "nostr.client.KeyManager"
     ), patch.object(
         NostrClient, "initialize_client_pool"