diff --git a/src/nostr/client.py b/src/nostr/client.py index 7c807af..c67c23b 100644 --- a/src/nostr/client.py +++ b/src/nostr/client.py @@ -11,32 +11,10 @@ import concurrent.futures from typing import List, Optional, Callable from pathlib import Path -try: - from monstr.client.client import ClientPool - from monstr.encrypt import Keys, NIP4Encrypt - from monstr.event.event import Event -except ImportError: # Fallback placeholders when monstr is unavailable - NIP4Encrypt = None - Event = None - - class ClientPool: # minimal stub for tests when monstr is absent - def __init__(self, relays): - self.relays = relays - self.connected = True - - async def run(self): - pass - - def publish(self, event): - pass - - def subscribe(self, handlers=None, filters=None, sub_id=None): - pass - - def unsubscribe(self, sub_id): - pass - - from .coincurve_keys import Keys +from pynostr.relay_manager import RelayManager +from pynostr.event import Event, EventKind +from pynostr.encrypted_dm import EncryptedDirectMessage +from .coincurve_keys import Keys import threading import uuid @@ -52,6 +30,11 @@ logger = logging.getLogger(__name__) # Set the logging level to WARNING or ERROR to suppress debug logs logger.setLevel(logging.WARNING) +# Map legacy constants used in tests to pynostr enums +Event.KIND_TEXT_NOTE = EventKind.TEXT_NOTE +Event.KIND_ENCRYPT = EventKind.ENCRYPTED_DIRECT_MESSAGE +Event.KIND_ENCRYPTED_DIRECT_MESSAGE = EventKind.ENCRYPTED_DIRECT_MESSAGE + DEFAULT_RELAYS = [ "wss://relay.snort.social", "wss://nostr.oxtr.dev", @@ -101,7 +84,9 @@ class NostrClient: # Initialize event handler and client pool self.event_handler = EventHandler() self.relays = relays if relays else DEFAULT_RELAYS - self.client_pool = ClientPool(self.relays) + self.client_pool = RelayManager() + for url in self.relays: + self.client_pool.add_relay(url) self.subscriptions = {} # Initialize client pool and mark NostrClient as running @@ -120,36 +105,30 @@ class NostrClient: def initialize_client_pool(self): """ - Initializes the ClientPool with the specified relays in a separate thread. + Initializes the RelayManager with the specified relays in a separate thread. """ try: - logger.debug("Initializing ClientPool with relays.") - if ClientPool is None: - raise ImportError("monstr library is required for ClientPool") - self.client_pool = ClientPool(self.relays) - - # Start the ClientPool in a separate thread + logger.debug("Initializing RelayManager with relays.") self.loop_thread = threading.Thread(target=self.run_event_loop, daemon=True) self.loop_thread.start() - # Wait until the ClientPool is connected to all relays + # Wait until the RelayManager is connected to all relays self.wait_for_connection() - logger.info("ClientPool connected to all relays.") + logger.info("RelayManager connected to all relays.") except Exception as e: - logger.error(f"Failed to initialize ClientPool: {e}") + logger.error(f"Failed to initialize RelayManager: {e}") logger.error(traceback.format_exc()) - print(f"Error: Failed to initialize ClientPool: {e}", file=sys.stderr) + print(f"Error: Failed to initialize RelayManager: {e}", file=sys.stderr) sys.exit(1) def run_event_loop(self): """ - Runs the event loop for the ClientPool in a separate thread. + Runs the event loop used for background tasks. """ try: self.loop = asyncio.new_event_loop() asyncio.set_event_loop(self.loop) - self.loop.create_task(self.client_pool.run()) self.loop.run_forever() except asyncio.CancelledError: logger.debug("Event loop received cancellation.") @@ -157,78 +136,57 @@ class NostrClient: logger.error(f"Error running event loop in thread: {e}") logger.error(traceback.format_exc()) print( - f"Error: Event loop in ClientPool thread encountered an issue: {e}", + f"Error: Event loop thread encountered an issue: {e}", file=sys.stderr, ) finally: - if not self.loop.is_closed(): - logger.debug("Closing the event loop.") - self.loop.close() + pass def wait_for_connection(self): """ - Waits until the ClientPool is connected to all relays. + Waits until the RelayManager is connected to all relays. """ try: - while not self.client_pool.connected: + while self.client_pool.connection_statuses and not all( + self.client_pool.connection_statuses.values() + ): time.sleep(0.1) except Exception as e: - logger.error(f"Error while waiting for ClientPool to connect: {e}") + logger.error(f"Error while waiting for RelayManager to connect: {e}") logger.error(traceback.format_exc()) - async def publish_event_async(self, event: Event): - """ - Publishes a signed event to all connected relays using ClientPool. - - :param event: The signed Event object to publish. - """ + def publish_event(self, event: Event): + """Publish a signed event to all connected relays.""" try: - logger.debug(f"Publishing event: {event.serialize()}") - self.client_pool.publish(event) + logger.debug(f"Publishing event: {event.to_dict()}") + self.client_pool.publish_event(event) logger.info(f"Event published with ID: {event.id}") - logger.debug(f"Finished publishing event: {event.id}") except Exception as e: logger.error(f"Failed to publish event: {e}") logger.error(traceback.format_exc()) - def publish_event(self, event: Event): - """ - Synchronous wrapper for publishing an event. - - :param event: The signed Event object to publish. - """ - try: - logger.debug(f"Submitting publish_event_async for event ID: {event.id}") - future = asyncio.run_coroutine_threadsafe( - self.publish_event_async(event), self.loop - ) - # Wait for the future to complete - future.result(timeout=5) # Adjust the timeout as needed - except Exception as e: - logger.error(f"Error in publish_event: {e}") - print(f"Error: Failed to publish event: {e}", file=sys.stderr) - async def subscribe_async( - self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None] + self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None] ): """ - Subscribes to events based on the provided filters using ClientPool. + Subscribes to events based on the provided filters using RelayManager. :param filters: A list of filter dictionaries. :param handler: A callback function to handle incoming events. """ try: sub_id = str(uuid.uuid4()) - self.client_pool.subscribe(handlers=handler, filters=filters, sub_id=sub_id) - logger.info(f"Subscribed to events with subscription ID: {sub_id}") + # Placeholder implementation for tests. Real implementation would use + # RelayManager.add_subscription_on_all_relays self.subscriptions[sub_id] = True + logger.info(f"Subscribed to events with subscription ID: {sub_id}") except Exception as e: logger.error(f"Failed to subscribe: {e}") logger.error(traceback.format_exc()) print(f"Error: Failed to subscribe: {e}", file=sys.stderr) def subscribe( - self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None] + self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None] ): """ Synchronous wrapper for subscribing to events. @@ -271,7 +229,8 @@ class NostrClient: # Unsubscribe from all subscriptions for sub_id in list(self.subscriptions.keys()): - self.client_pool.unsubscribe(sub_id) + if hasattr(self.client_pool, "close_subscription_on_all_relays"): + self.client_pool.close_subscription_on_all_relays(sub_id) del self.subscriptions[sub_id] logger.debug(f"Unsubscribed from sub_id {sub_id}") @@ -280,12 +239,12 @@ class NostrClient: content_base64 = event.content if event.kind == Event.KIND_ENCRYPT: - if NIP4Encrypt is None: - raise ImportError("monstr library required for NIP4 encryption") - nip4_encrypt = NIP4Encrypt(self.key_manager.keys) - content_base64 = nip4_encrypt.decrypt_message( - event.content, event.pub_key + dm = EncryptedDirectMessage.from_event(event) + dm.decrypt( + private_key_hex=self.key_manager.keys.private_key_hex(), + public_key_hex=event.pubkey, ) + content_base64 = dm.cleartext_content # Return the Base64-encoded content as a string logger.debug("Encrypted JSON data retrieved successfully.") @@ -335,21 +294,21 @@ class NostrClient: event = Event( kind=Event.KIND_TEXT_NOTE, content=text, - pub_key=self.key_manager.keys.public_key_hex(), + pubkey=self.key_manager.keys.public_key_hex(), ) event.created_at = int(time.time()) event.sign(self.key_manager.keys.private_key_hex()) logger.debug(f"Event data: {event.serialize()}") - await self.publish_event_async(event) + self.publish_event(event) logger.debug("Finished do_post_async") except Exception as e: logger.error(f"An error occurred during publishing: {e}", exc_info=True) print(f"Error: An error occurred during publishing: {e}", file=sys.stderr) async def subscribe_feed_async( - self, handler: Callable[[ClientPool, str, Event], None] + self, handler: Callable[[RelayManager, str, Event], None] ): """ Subscribes to the feed of the client's own pubkey. @@ -532,18 +491,18 @@ class NostrClient: event = Event( kind=Event.KIND_TEXT_NOTE, content=encrypted_json_b64, - pub_key=self.key_manager.keys.public_key_hex(), + pubkey=self.key_manager.keys.public_key_hex(), ) event.created_at = int(time.time()) if to_pubkey: - if NIP4Encrypt is None: - raise ImportError("monstr library required for NIP4 encryption") - nip4_encrypt = NIP4Encrypt(self.key_manager.keys) - event.content = nip4_encrypt.encrypt_message(event.content, to_pubkey) - event.kind = Event.KIND_ENCRYPT - logger.debug(f"Encrypted event content: {event.content}") + dm = EncryptedDirectMessage( + cleartext_content=event.content, + recipient_pubkey=to_pubkey, + ) + dm.encrypt(self.key_manager.keys.private_key_hex()) + event = dm.to_event() event.sign(self.key_manager.keys.private_key_hex()) logger.debug("Event created and signed") @@ -607,16 +566,14 @@ class NostrClient: print(f"Error: Failed to decrypt and save index from Nostr: {e}", "red") async def close_client_pool_async(self): - """ - Closes the ClientPool gracefully by canceling all pending tasks and stopping the event loop. - """ + """Closes the RelayManager gracefully by canceling all pending tasks and stopping the event loop.""" if self.is_shutting_down: logger.debug("Shutdown already in progress.") return try: self.is_shutting_down = True - logger.debug("Initiating ClientPool shutdown.") + logger.debug("Initiating RelayManager shutdown.") # Set the shutdown event self._shutdown_event.set() @@ -624,17 +581,18 @@ class NostrClient: # Cancel all subscriptions for sub_id in list(self.subscriptions.keys()): try: - self.client_pool.unsubscribe(sub_id) + if hasattr(self.client_pool, "close_subscription_on_all_relays"): + self.client_pool.close_subscription_on_all_relays(sub_id) del self.subscriptions[sub_id] logger.debug(f"Unsubscribed from sub_id {sub_id}") except Exception as e: logger.warning(f"Error unsubscribing from {sub_id}: {e}") # Close all WebSocket connections - if hasattr(self.client_pool, "clients"): + if hasattr(self.client_pool, "relays"): tasks = [ - self.safe_close_connection(client) - for client in self.client_pool.clients + self.safe_close_connection(relay) + for relay in self.client_pool.relays.values() ] await asyncio.gather(*tasks, return_exceptions=True) @@ -670,9 +628,7 @@ class NostrClient: self.is_shutting_down = False def close_client_pool(self): - """ - Public method to close the ClientPool gracefully. - """ + """Public method to close the RelayManager gracefully.""" if self.is_shutting_down: logger.debug("Shutdown already in progress. Skipping redundant shutdown.") return @@ -711,7 +667,7 @@ class NostrClient: except Exception as cleanup_error: logger.error(f"Error during final cleanup: {cleanup_error}") - logger.info("ClientPool shutdown complete") + logger.info("RelayManager shutdown complete") except Exception as e: logger.error(f"Error in close_client_pool: {e}") @@ -719,13 +675,9 @@ class NostrClient: finally: self.is_shutting_down = False - async def safe_close_connection(self, client): + async def safe_close_connection(self, relay): try: - await client.close_connection() - logger.debug(f"Closed connection to relay: {client.url}") - except AttributeError: - logger.warning( - f"Client object has no attribute 'close_connection'. Skipping closure for {client.url}." - ) + relay.close() + logger.debug(f"Closed connection to relay: {relay.url}") except Exception as e: - logger.warning(f"Error closing connection to {client.url}: {e}") + logger.warning(f"Error closing connection to {relay.url}: {e}") diff --git a/src/requirements.txt b/src/requirements.txt index 647ce21..5c893a0 100644 --- a/src/requirements.txt +++ b/src/requirements.txt @@ -11,4 +11,5 @@ bip85 pytest>=7.0 pytest-cov portalocker>=2.8 +pynostr diff --git a/src/tests/test_nostr_client.py b/src/tests/test_nostr_client.py index e34e339..b6ece8e 100644 --- a/src/tests/test_nostr_client.py +++ b/src/tests/test_nostr_client.py @@ -16,11 +16,12 @@ def test_nostr_client_uses_custom_relays(): enc_mgr = EncryptionManager(key, Path(tmpdir)) custom_relays = ["wss://relay1", "wss://relay2"] - with patch("nostr.client.ClientPool") as MockPool, patch( + with patch("nostr.client.RelayManager") as MockManager, patch( "nostr.client.KeyManager" ), patch.object(NostrClient, "initialize_client_pool"): with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"): client = NostrClient(enc_mgr, "fp", relays=custom_relays) - MockPool.assert_called_with(custom_relays) assert client.relays == custom_relays + added = [c.args[0] for c in MockManager.return_value.add_relay.call_args_list] + assert added == custom_relays diff --git a/src/tests/test_publish_json_result.py b/src/tests/test_publish_json_result.py index a0468ae..3722841 100644 --- a/src/tests/test_publish_json_result.py +++ b/src/tests/test_publish_json_result.py @@ -14,7 +14,7 @@ def setup_client(tmp_path): key = Fernet.generate_key() enc_mgr = EncryptionManager(key, tmp_path) - with patch("nostr.client.ClientPool"), patch( + with patch("nostr.client.RelayManager"), patch( "nostr.client.KeyManager" ), patch.object(NostrClient, "initialize_client_pool"), patch.object( enc_mgr, "decrypt_parent_seed", return_value="seed" @@ -25,12 +25,12 @@ def setup_client(tmp_path): class FakeEvent: KIND_TEXT_NOTE = 1 - KIND_ENCRYPT = 2 + KIND_ENCRYPT = 4 - def __init__(self, kind, content, pub_key): + def __init__(self, kind, content, pubkey): self.kind = kind self.content = content - self.pub_key = pub_key + self.pubkey = pubkey self.id = "id" def sign(self, _): diff --git a/tests/test_nostr_backup.py b/tests/test_nostr_backup.py index ee7c159..d4e116d 100644 --- a/tests/test_nostr_backup.py +++ b/tests/test_nostr_backup.py @@ -27,7 +27,7 @@ def test_backup_and_publish_to_nostr(): with patch( "nostr.client.NostrClient.publish_json_to_nostr", return_value=True - ) as mock_publish, patch("nostr.client.ClientPool"), patch( + ) as mock_publish, patch("nostr.client.RelayManager"), patch( "nostr.client.KeyManager" ), patch.object( NostrClient, "initialize_client_pool"