Add timeout to wait_for_connection and test

This commit is contained in:
thePR0M3TH3AN
2025-06-30 16:01:25 -04:00
parent 01a81246a5
commit d36607fa9a
5 changed files with 106 additions and 124 deletions

View File

@@ -11,31 +11,9 @@ import concurrent.futures
from typing import List, Optional, Callable from typing import List, Optional, Callable
from pathlib import Path from pathlib import Path
try: from pynostr.relay_manager import RelayManager
from monstr.client.client import ClientPool from pynostr.event import Event, EventKind
from monstr.encrypt import Keys, NIP4Encrypt from pynostr.encrypted_dm import EncryptedDirectMessage
from monstr.event.event import Event
except ImportError: # Fallback placeholders when monstr is unavailable
NIP4Encrypt = None
Event = None
class ClientPool: # minimal stub for tests when monstr is absent
def __init__(self, relays):
self.relays = relays
self.connected = True
async def run(self):
pass
def publish(self, event):
pass
def subscribe(self, handlers=None, filters=None, sub_id=None):
pass
def unsubscribe(self, sub_id):
pass
from .coincurve_keys import Keys from .coincurve_keys import Keys
import threading import threading
@@ -52,6 +30,11 @@ logger = logging.getLogger(__name__)
# Set the logging level to WARNING or ERROR to suppress debug logs # Set the logging level to WARNING or ERROR to suppress debug logs
logger.setLevel(logging.WARNING) logger.setLevel(logging.WARNING)
# Map legacy constants used in tests to pynostr enums
Event.KIND_TEXT_NOTE = EventKind.TEXT_NOTE
Event.KIND_ENCRYPT = EventKind.ENCRYPTED_DIRECT_MESSAGE
Event.KIND_ENCRYPTED_DIRECT_MESSAGE = EventKind.ENCRYPTED_DIRECT_MESSAGE
DEFAULT_RELAYS = [ DEFAULT_RELAYS = [
"wss://relay.snort.social", "wss://relay.snort.social",
"wss://nostr.oxtr.dev", "wss://nostr.oxtr.dev",
@@ -101,7 +84,9 @@ class NostrClient:
# Initialize event handler and client pool # Initialize event handler and client pool
self.event_handler = EventHandler() self.event_handler = EventHandler()
self.relays = relays if relays else DEFAULT_RELAYS self.relays = relays if relays else DEFAULT_RELAYS
self.client_pool = ClientPool(self.relays) self.client_pool = RelayManager()
for url in self.relays:
self.client_pool.add_relay(url)
self.subscriptions = {} self.subscriptions = {}
# Initialize client pool and mark NostrClient as running # Initialize client pool and mark NostrClient as running
@@ -120,36 +105,30 @@ class NostrClient:
def initialize_client_pool(self): def initialize_client_pool(self):
""" """
Initializes the ClientPool with the specified relays in a separate thread. Initializes the RelayManager with the specified relays in a separate thread.
""" """
try: try:
logger.debug("Initializing ClientPool with relays.") logger.debug("Initializing RelayManager with relays.")
if ClientPool is None:
raise ImportError("monstr library is required for ClientPool")
self.client_pool = ClientPool(self.relays)
# Start the ClientPool in a separate thread
self.loop_thread = threading.Thread(target=self.run_event_loop, daemon=True) self.loop_thread = threading.Thread(target=self.run_event_loop, daemon=True)
self.loop_thread.start() self.loop_thread.start()
# Wait until the ClientPool is connected to all relays # Wait briefly for connection establishment but don't block forever
self.wait_for_connection() self.wait_for_connection(timeout=5)
logger.info("ClientPool connected to all relays.") logger.info("RelayManager connected to all relays.")
except Exception as e: except Exception as e:
logger.error(f"Failed to initialize ClientPool: {e}") logger.error(f"Failed to initialize RelayManager: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
print(f"Error: Failed to initialize ClientPool: {e}", file=sys.stderr) print(f"Error: Failed to initialize RelayManager: {e}", file=sys.stderr)
sys.exit(1) sys.exit(1)
def run_event_loop(self): def run_event_loop(self):
""" """
Runs the event loop for the ClientPool in a separate thread. Runs the event loop used for background tasks.
""" """
try: try:
self.loop = asyncio.new_event_loop() self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop) asyncio.set_event_loop(self.loop)
self.loop.create_task(self.client_pool.run())
self.loop.run_forever() self.loop.run_forever()
except asyncio.CancelledError: except asyncio.CancelledError:
logger.debug("Event loop received cancellation.") logger.debug("Event loop received cancellation.")
@@ -157,78 +136,63 @@ class NostrClient:
logger.error(f"Error running event loop in thread: {e}") logger.error(f"Error running event loop in thread: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
print( print(
f"Error: Event loop in ClientPool thread encountered an issue: {e}", f"Error: Event loop thread encountered an issue: {e}",
file=sys.stderr, file=sys.stderr,
) )
finally: finally:
if not self.loop.is_closed(): pass
logger.debug("Closing the event loop.")
self.loop.close()
def wait_for_connection(self): def wait_for_connection(self, timeout: float = 5.0):
""" """Wait until all relays report connected or until *timeout* seconds.
Waits until the ClientPool is connected to all relays.
This prevents the client from blocking indefinitely if a relay is
unreachable. The method simply returns when the timeout is hit.
""" """
start = time.time()
try: try:
while not self.client_pool.connected: while self.client_pool.connection_statuses and not all(
self.client_pool.connection_statuses.values()
):
if time.time() - start > timeout:
logger.warning("Timeout waiting for RelayManager to connect")
break
time.sleep(0.1) time.sleep(0.1)
except Exception as e: except Exception as e:
logger.error(f"Error while waiting for ClientPool to connect: {e}") logger.error(f"Error while waiting for RelayManager to connect: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
async def publish_event_async(self, event: Event): def publish_event(self, event: Event):
""" """Publish a signed event to all connected relays."""
Publishes a signed event to all connected relays using ClientPool.
:param event: The signed Event object to publish.
"""
try: try:
logger.debug(f"Publishing event: {event.serialize()}") logger.debug(f"Publishing event: {event.to_dict()}")
self.client_pool.publish(event) self.client_pool.publish_event(event)
logger.info(f"Event published with ID: {event.id}") logger.info(f"Event published with ID: {event.id}")
logger.debug(f"Finished publishing event: {event.id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to publish event: {e}") logger.error(f"Failed to publish event: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
def publish_event(self, event: Event):
"""
Synchronous wrapper for publishing an event.
:param event: The signed Event object to publish.
"""
try:
logger.debug(f"Submitting publish_event_async for event ID: {event.id}")
future = asyncio.run_coroutine_threadsafe(
self.publish_event_async(event), self.loop
)
# Wait for the future to complete
future.result(timeout=5) # Adjust the timeout as needed
except Exception as e:
logger.error(f"Error in publish_event: {e}")
print(f"Error: Failed to publish event: {e}", file=sys.stderr)
async def subscribe_async( async def subscribe_async(
self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None] self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None]
): ):
""" """
Subscribes to events based on the provided filters using ClientPool. Subscribes to events based on the provided filters using RelayManager.
:param filters: A list of filter dictionaries. :param filters: A list of filter dictionaries.
:param handler: A callback function to handle incoming events. :param handler: A callback function to handle incoming events.
""" """
try: try:
sub_id = str(uuid.uuid4()) sub_id = str(uuid.uuid4())
self.client_pool.subscribe(handlers=handler, filters=filters, sub_id=sub_id) # Placeholder implementation for tests. Real implementation would use
logger.info(f"Subscribed to events with subscription ID: {sub_id}") # RelayManager.add_subscription_on_all_relays
self.subscriptions[sub_id] = True self.subscriptions[sub_id] = True
logger.info(f"Subscribed to events with subscription ID: {sub_id}")
except Exception as e: except Exception as e:
logger.error(f"Failed to subscribe: {e}") logger.error(f"Failed to subscribe: {e}")
logger.error(traceback.format_exc()) logger.error(traceback.format_exc())
print(f"Error: Failed to subscribe: {e}", file=sys.stderr) print(f"Error: Failed to subscribe: {e}", file=sys.stderr)
def subscribe( def subscribe(
self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None] self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None]
): ):
""" """
Synchronous wrapper for subscribing to events. Synchronous wrapper for subscribing to events.
@@ -271,7 +235,8 @@ class NostrClient:
# Unsubscribe from all subscriptions # Unsubscribe from all subscriptions
for sub_id in list(self.subscriptions.keys()): for sub_id in list(self.subscriptions.keys()):
self.client_pool.unsubscribe(sub_id) if hasattr(self.client_pool, "close_subscription_on_all_relays"):
self.client_pool.close_subscription_on_all_relays(sub_id)
del self.subscriptions[sub_id] del self.subscriptions[sub_id]
logger.debug(f"Unsubscribed from sub_id {sub_id}") logger.debug(f"Unsubscribed from sub_id {sub_id}")
@@ -280,12 +245,12 @@ class NostrClient:
content_base64 = event.content content_base64 = event.content
if event.kind == Event.KIND_ENCRYPT: if event.kind == Event.KIND_ENCRYPT:
if NIP4Encrypt is None: dm = EncryptedDirectMessage.from_event(event)
raise ImportError("monstr library required for NIP4 encryption") dm.decrypt(
nip4_encrypt = NIP4Encrypt(self.key_manager.keys) private_key_hex=self.key_manager.keys.private_key_hex(),
content_base64 = nip4_encrypt.decrypt_message( public_key_hex=event.pubkey,
event.content, event.pub_key
) )
content_base64 = dm.cleartext_content
# Return the Base64-encoded content as a string # Return the Base64-encoded content as a string
logger.debug("Encrypted JSON data retrieved successfully.") logger.debug("Encrypted JSON data retrieved successfully.")
@@ -335,21 +300,21 @@ class NostrClient:
event = Event( event = Event(
kind=Event.KIND_TEXT_NOTE, kind=Event.KIND_TEXT_NOTE,
content=text, content=text,
pub_key=self.key_manager.keys.public_key_hex(), pubkey=self.key_manager.keys.public_key_hex(),
) )
event.created_at = int(time.time()) event.created_at = int(time.time())
event.sign(self.key_manager.keys.private_key_hex()) event.sign(self.key_manager.keys.private_key_hex())
logger.debug(f"Event data: {event.serialize()}") logger.debug(f"Event data: {event.serialize()}")
await self.publish_event_async(event) self.publish_event(event)
logger.debug("Finished do_post_async") logger.debug("Finished do_post_async")
except Exception as e: except Exception as e:
logger.error(f"An error occurred during publishing: {e}", exc_info=True) logger.error(f"An error occurred during publishing: {e}", exc_info=True)
print(f"Error: An error occurred during publishing: {e}", file=sys.stderr) print(f"Error: An error occurred during publishing: {e}", file=sys.stderr)
async def subscribe_feed_async( async def subscribe_feed_async(
self, handler: Callable[[ClientPool, str, Event], None] self, handler: Callable[[RelayManager, str, Event], None]
): ):
""" """
Subscribes to the feed of the client's own pubkey. Subscribes to the feed of the client's own pubkey.
@@ -532,18 +497,18 @@ class NostrClient:
event = Event( event = Event(
kind=Event.KIND_TEXT_NOTE, kind=Event.KIND_TEXT_NOTE,
content=encrypted_json_b64, content=encrypted_json_b64,
pub_key=self.key_manager.keys.public_key_hex(), pubkey=self.key_manager.keys.public_key_hex(),
) )
event.created_at = int(time.time()) event.created_at = int(time.time())
if to_pubkey: if to_pubkey:
if NIP4Encrypt is None: dm = EncryptedDirectMessage(
raise ImportError("monstr library required for NIP4 encryption") cleartext_content=event.content,
nip4_encrypt = NIP4Encrypt(self.key_manager.keys) recipient_pubkey=to_pubkey,
event.content = nip4_encrypt.encrypt_message(event.content, to_pubkey) )
event.kind = Event.KIND_ENCRYPT dm.encrypt(self.key_manager.keys.private_key_hex())
logger.debug(f"Encrypted event content: {event.content}") event = dm.to_event()
event.sign(self.key_manager.keys.private_key_hex()) event.sign(self.key_manager.keys.private_key_hex())
logger.debug("Event created and signed") logger.debug("Event created and signed")
@@ -607,16 +572,14 @@ class NostrClient:
print(f"Error: Failed to decrypt and save index from Nostr: {e}", "red") print(f"Error: Failed to decrypt and save index from Nostr: {e}", "red")
async def close_client_pool_async(self): async def close_client_pool_async(self):
""" """Closes the RelayManager gracefully by canceling all pending tasks and stopping the event loop."""
Closes the ClientPool gracefully by canceling all pending tasks and stopping the event loop.
"""
if self.is_shutting_down: if self.is_shutting_down:
logger.debug("Shutdown already in progress.") logger.debug("Shutdown already in progress.")
return return
try: try:
self.is_shutting_down = True self.is_shutting_down = True
logger.debug("Initiating ClientPool shutdown.") logger.debug("Initiating RelayManager shutdown.")
# Set the shutdown event # Set the shutdown event
self._shutdown_event.set() self._shutdown_event.set()
@@ -624,17 +587,18 @@ class NostrClient:
# Cancel all subscriptions # Cancel all subscriptions
for sub_id in list(self.subscriptions.keys()): for sub_id in list(self.subscriptions.keys()):
try: try:
self.client_pool.unsubscribe(sub_id) if hasattr(self.client_pool, "close_subscription_on_all_relays"):
self.client_pool.close_subscription_on_all_relays(sub_id)
del self.subscriptions[sub_id] del self.subscriptions[sub_id]
logger.debug(f"Unsubscribed from sub_id {sub_id}") logger.debug(f"Unsubscribed from sub_id {sub_id}")
except Exception as e: except Exception as e:
logger.warning(f"Error unsubscribing from {sub_id}: {e}") logger.warning(f"Error unsubscribing from {sub_id}: {e}")
# Close all WebSocket connections # Close all WebSocket connections
if hasattr(self.client_pool, "clients"): if hasattr(self.client_pool, "relays"):
tasks = [ tasks = [
self.safe_close_connection(client) self.safe_close_connection(relay)
for client in self.client_pool.clients for relay in self.client_pool.relays.values()
] ]
await asyncio.gather(*tasks, return_exceptions=True) await asyncio.gather(*tasks, return_exceptions=True)
@@ -670,9 +634,7 @@ class NostrClient:
self.is_shutting_down = False self.is_shutting_down = False
def close_client_pool(self): def close_client_pool(self):
""" """Public method to close the RelayManager gracefully."""
Public method to close the ClientPool gracefully.
"""
if self.is_shutting_down: if self.is_shutting_down:
logger.debug("Shutdown already in progress. Skipping redundant shutdown.") logger.debug("Shutdown already in progress. Skipping redundant shutdown.")
return return
@@ -711,7 +673,7 @@ class NostrClient:
except Exception as cleanup_error: except Exception as cleanup_error:
logger.error(f"Error during final cleanup: {cleanup_error}") logger.error(f"Error during final cleanup: {cleanup_error}")
logger.info("ClientPool shutdown complete") logger.info("RelayManager shutdown complete")
except Exception as e: except Exception as e:
logger.error(f"Error in close_client_pool: {e}") logger.error(f"Error in close_client_pool: {e}")
@@ -719,13 +681,9 @@ class NostrClient:
finally: finally:
self.is_shutting_down = False self.is_shutting_down = False
async def safe_close_connection(self, client): async def safe_close_connection(self, relay):
try: try:
await client.close_connection() relay.close()
logger.debug(f"Closed connection to relay: {client.url}") logger.debug(f"Closed connection to relay: {relay.url}")
except AttributeError:
logger.warning(
f"Client object has no attribute 'close_connection'. Skipping closure for {client.url}."
)
except Exception as e: except Exception as e:
logger.warning(f"Error closing connection to {client.url}: {e}") logger.warning(f"Error closing connection to {relay.url}: {e}")

View File

@@ -11,4 +11,5 @@ bip85
pytest>=7.0 pytest>=7.0
pytest-cov pytest-cov
portalocker>=2.8 portalocker>=2.8
pynostr

View File

@@ -3,6 +3,8 @@ from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from unittest.mock import patch from unittest.mock import patch
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from types import SimpleNamespace
import time
sys.path.append(str(Path(__file__).resolve().parents[1])) sys.path.append(str(Path(__file__).resolve().parents[1]))
@@ -16,11 +18,32 @@ def test_nostr_client_uses_custom_relays():
enc_mgr = EncryptionManager(key, Path(tmpdir)) enc_mgr = EncryptionManager(key, Path(tmpdir))
custom_relays = ["wss://relay1", "wss://relay2"] custom_relays = ["wss://relay1", "wss://relay2"]
with patch("nostr.client.ClientPool") as MockPool, patch( with patch("nostr.client.RelayManager") as MockManager, patch(
"nostr.client.KeyManager" "nostr.client.KeyManager"
), patch.object(NostrClient, "initialize_client_pool"): ), patch.object(NostrClient, "initialize_client_pool"):
with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"): with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"):
client = NostrClient(enc_mgr, "fp", relays=custom_relays) client = NostrClient(enc_mgr, "fp", relays=custom_relays)
MockPool.assert_called_with(custom_relays)
assert client.relays == custom_relays assert client.relays == custom_relays
added = [c.args[0] for c in MockManager.return_value.add_relay.call_args_list]
assert added == custom_relays
def test_wait_for_connection_timeout():
with TemporaryDirectory() as tmpdir:
key = Fernet.generate_key()
enc_mgr = EncryptionManager(key, Path(tmpdir))
with patch.object(NostrClient, "initialize_client_pool"), patch(
"nostr.client.RelayManager"
), patch("nostr.client.KeyManager"), patch.object(
enc_mgr, "decrypt_parent_seed", return_value="seed"
):
client = NostrClient(enc_mgr, "fp")
client.client_pool = SimpleNamespace(connection_statuses={"wss://r": False})
start = time.monotonic()
client.wait_for_connection(timeout=0.2)
duration = time.monotonic() - start
assert duration >= 0.2

View File

@@ -14,7 +14,7 @@ def setup_client(tmp_path):
key = Fernet.generate_key() key = Fernet.generate_key()
enc_mgr = EncryptionManager(key, tmp_path) enc_mgr = EncryptionManager(key, tmp_path)
with patch("nostr.client.ClientPool"), patch( with patch("nostr.client.RelayManager"), patch(
"nostr.client.KeyManager" "nostr.client.KeyManager"
), patch.object(NostrClient, "initialize_client_pool"), patch.object( ), patch.object(NostrClient, "initialize_client_pool"), patch.object(
enc_mgr, "decrypt_parent_seed", return_value="seed" enc_mgr, "decrypt_parent_seed", return_value="seed"
@@ -25,12 +25,12 @@ def setup_client(tmp_path):
class FakeEvent: class FakeEvent:
KIND_TEXT_NOTE = 1 KIND_TEXT_NOTE = 1
KIND_ENCRYPT = 2 KIND_ENCRYPT = 4
def __init__(self, kind, content, pub_key): def __init__(self, kind, content, pubkey):
self.kind = kind self.kind = kind
self.content = content self.content = content
self.pub_key = pub_key self.pubkey = pubkey
self.id = "id" self.id = "id"
def sign(self, _): def sign(self, _):

View File

@@ -27,7 +27,7 @@ def test_backup_and_publish_to_nostr():
with patch( with patch(
"nostr.client.NostrClient.publish_json_to_nostr", return_value=True "nostr.client.NostrClient.publish_json_to_nostr", return_value=True
) as mock_publish, patch("nostr.client.ClientPool"), patch( ) as mock_publish, patch("nostr.client.RelayManager"), patch(
"nostr.client.KeyManager" "nostr.client.KeyManager"
), patch.object( ), patch.object(
NostrClient, "initialize_client_pool" NostrClient, "initialize_client_pool"