Revert "Merge pull request #66 from PR0M3TH3AN/codex/replace-monstr-with-pynostr-in-client"

This reverts commit c79dd805e3, reversing
changes made to c05f19d3a4.
thePR0M3TH3AN committed 2025-06-30 16:07:10 -04:00
parent c79dd805e3
commit 87a493b845
5 changed files with 124 additions and 106 deletions

View File

@@ -11,10 +11,32 @@ import concurrent.futures
 from typing import List, Optional, Callable
 from pathlib import Path
 
-from pynostr.relay_manager import RelayManager
-from pynostr.event import Event, EventKind
-from pynostr.encrypted_dm import EncryptedDirectMessage
-from .coincurve_keys import Keys
+try:
+    from monstr.client.client import ClientPool
+    from monstr.encrypt import Keys, NIP4Encrypt
+    from monstr.event.event import Event
+except ImportError:  # Fallback placeholders when monstr is unavailable
+    NIP4Encrypt = None
+    Event = None
+
+    class ClientPool:  # minimal stub for tests when monstr is absent
+        def __init__(self, relays):
+            self.relays = relays
+            self.connected = True
+
+        async def run(self):
+            pass
+
+        def publish(self, event):
+            pass
+
+        def subscribe(self, handlers=None, filters=None, sub_id=None):
+            pass
+
+        def unsubscribe(self, sub_id):
+            pass
+
+    from .coincurve_keys import Keys
 
 import threading
 import uuid
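
The hunk above restores an optional-dependency import: monstr is used when installed, and a minimal ClientPool stub is defined otherwise so the module (and its tests) still import. A standalone sketch of that pattern follows; the package name somelib and the helper make_client are made up for illustration, and only the try/except-with-stub shape mirrors the real change.

# Sketch of the optional-dependency fallback pattern (hypothetical names).
try:
    from somelib import Client  # assumed real implementation, if installed
except ImportError:
    class Client:  # lightweight stub so imports and tests still work
        def __init__(self, relays):
            self.relays = relays
            self.connected = True

        def publish(self, event):
            pass  # no-op when the real library is absent


def make_client(relays):
    """Return whichever Client implementation is available."""
    return Client(relays)


if __name__ == "__main__":
    print(type(make_client(["wss://example.relay"])).__name__)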
@@ -30,11 +52,6 @@ logger = logging.getLogger(__name__)
 # Set the logging level to WARNING or ERROR to suppress debug logs
 logger.setLevel(logging.WARNING)
 
-# Map legacy constants used in tests to pynostr enums
-Event.KIND_TEXT_NOTE = EventKind.TEXT_NOTE
-Event.KIND_ENCRYPT = EventKind.ENCRYPTED_DIRECT_MESSAGE
-Event.KIND_ENCRYPTED_DIRECT_MESSAGE = EventKind.ENCRYPTED_DIRECT_MESSAGE
-
 DEFAULT_RELAYS = [
     "wss://relay.snort.social",
     "wss://nostr.oxtr.dev",
@@ -84,9 +101,7 @@ class NostrClient:
         # Initialize event handler and client pool
         self.event_handler = EventHandler()
         self.relays = relays if relays else DEFAULT_RELAYS
-        self.client_pool = RelayManager()
-        for url in self.relays:
-            self.client_pool.add_relay(url)
+        self.client_pool = ClientPool(self.relays)
         self.subscriptions = {}
 
         # Initialize client pool and mark NostrClient as running
@@ -105,30 +120,36 @@ class NostrClient:
     def initialize_client_pool(self):
         """
-        Initializes the RelayManager with the specified relays in a separate thread.
+        Initializes the ClientPool with the specified relays in a separate thread.
         """
         try:
-            logger.debug("Initializing RelayManager with relays.")
+            logger.debug("Initializing ClientPool with relays.")
+            if ClientPool is None:
+                raise ImportError("monstr library is required for ClientPool")
+            self.client_pool = ClientPool(self.relays)
+
+            # Start the ClientPool in a separate thread
             self.loop_thread = threading.Thread(target=self.run_event_loop, daemon=True)
             self.loop_thread.start()
 
-            # Wait briefly for connection establishment but don't block forever
-            self.wait_for_connection(timeout=5)
+            # Wait until the ClientPool is connected to all relays
+            self.wait_for_connection()
 
-            logger.info("RelayManager connected to all relays.")
+            logger.info("ClientPool connected to all relays.")
         except Exception as e:
-            logger.error(f"Failed to initialize RelayManager: {e}")
+            logger.error(f"Failed to initialize ClientPool: {e}")
             logger.error(traceback.format_exc())
-            print(f"Error: Failed to initialize RelayManager: {e}", file=sys.stderr)
+            print(f"Error: Failed to initialize ClientPool: {e}", file=sys.stderr)
             sys.exit(1)
 
     def run_event_loop(self):
         """
-        Runs the event loop used for background tasks.
+        Runs the event loop for the ClientPool in a separate thread.
         """
         try:
             self.loop = asyncio.new_event_loop()
             asyncio.set_event_loop(self.loop)
+            self.loop.create_task(self.client_pool.run())
             self.loop.run_forever()
         except asyncio.CancelledError:
             logger.debug("Event loop received cancellation.")
@@ -136,63 +157,78 @@ class NostrClient:
             logger.error(f"Error running event loop in thread: {e}")
             logger.error(traceback.format_exc())
             print(
-                f"Error: Event loop thread encountered an issue: {e}",
+                f"Error: Event loop in ClientPool thread encountered an issue: {e}",
                 file=sys.stderr,
             )
         finally:
-            pass
+            if not self.loop.is_closed():
+                logger.debug("Closing the event loop.")
+                self.loop.close()
 
-    def wait_for_connection(self, timeout: float = 5.0):
-        """Wait until all relays report connected or until *timeout* seconds.
-
-        This prevents the client from blocking indefinitely if a relay is
-        unreachable. The method simply returns when the timeout is hit.
+    def wait_for_connection(self):
+        """
+        Waits until the ClientPool is connected to all relays.
         """
-        start = time.time()
         try:
-            while self.client_pool.connection_statuses and not all(
-                self.client_pool.connection_statuses.values()
-            ):
-                if time.time() - start > timeout:
-                    logger.warning("Timeout waiting for RelayManager to connect")
-                    break
+            while not self.client_pool.connected:
                 time.sleep(0.1)
         except Exception as e:
-            logger.error(f"Error while waiting for RelayManager to connect: {e}")
+            logger.error(f"Error while waiting for ClientPool to connect: {e}")
             logger.error(traceback.format_exc())
 
-    def publish_event(self, event: Event):
-        """Publish a signed event to all connected relays."""
+    async def publish_event_async(self, event: Event):
+        """
+        Publishes a signed event to all connected relays using ClientPool.
+
+        :param event: The signed Event object to publish.
+        """
         try:
-            logger.debug(f"Publishing event: {event.to_dict()}")
-            self.client_pool.publish_event(event)
+            logger.debug(f"Publishing event: {event.serialize()}")
+            self.client_pool.publish(event)
             logger.info(f"Event published with ID: {event.id}")
+            logger.debug(f"Finished publishing event: {event.id}")
         except Exception as e:
             logger.error(f"Failed to publish event: {e}")
             logger.error(traceback.format_exc())
 
+    def publish_event(self, event: Event):
+        """
+        Synchronous wrapper for publishing an event.
+
+        :param event: The signed Event object to publish.
+        """
+        try:
+            logger.debug(f"Submitting publish_event_async for event ID: {event.id}")
+            future = asyncio.run_coroutine_threadsafe(
+                self.publish_event_async(event), self.loop
+            )
+            # Wait for the future to complete
+            future.result(timeout=5)  # Adjust the timeout as needed
+        except Exception as e:
+            logger.error(f"Error in publish_event: {e}")
+            print(f"Error: Failed to publish event: {e}", file=sys.stderr)
+
     async def subscribe_async(
-        self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None]
+        self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None]
     ):
         """
-        Subscribes to events based on the provided filters using RelayManager.
+        Subscribes to events based on the provided filters using ClientPool.
 
         :param filters: A list of filter dictionaries.
         :param handler: A callback function to handle incoming events.
         """
         try:
             sub_id = str(uuid.uuid4())
-            # Placeholder implementation for tests. Real implementation would use
-            # RelayManager.add_subscription_on_all_relays
-            self.subscriptions[sub_id] = True
+            self.client_pool.subscribe(handlers=handler, filters=filters, sub_id=sub_id)
             logger.info(f"Subscribed to events with subscription ID: {sub_id}")
+            self.subscriptions[sub_id] = True
         except Exception as e:
             logger.error(f"Failed to subscribe: {e}")
             logger.error(traceback.format_exc())
             print(f"Error: Failed to subscribe: {e}", file=sys.stderr)
 
     def subscribe(
-        self, filters: List[dict], handler: Callable[[RelayManager, str, Event], None]
+        self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None]
     ):
         """
         Synchronous wrapper for subscribing to events.
@@ -245,12 +280,12 @@ class NostrClient:
             content_base64 = event.content
 
             if event.kind == Event.KIND_ENCRYPT:
-                dm = EncryptedDirectMessage.from_event(event)
-                dm.decrypt(
-                    private_key_hex=self.key_manager.keys.private_key_hex(),
-                    public_key_hex=event.pubkey,
+                if NIP4Encrypt is None:
+                    raise ImportError("monstr library required for NIP4 encryption")
+                nip4_encrypt = NIP4Encrypt(self.key_manager.keys)
+                content_base64 = nip4_encrypt.decrypt_message(
+                    event.content, event.pub_key
                 )
-                content_base64 = dm.cleartext_content
 
             # Return the Base64-encoded content as a string
             logger.debug("Encrypted JSON data retrieved successfully.")
@@ -300,21 +335,21 @@ class NostrClient:
             event = Event(
                 kind=Event.KIND_TEXT_NOTE,
                 content=text,
-                pubkey=self.key_manager.keys.public_key_hex(),
+                pub_key=self.key_manager.keys.public_key_hex(),
             )
             event.created_at = int(time.time())
             event.sign(self.key_manager.keys.private_key_hex())
             logger.debug(f"Event data: {event.serialize()}")
 
-            self.publish_event(event)
+            await self.publish_event_async(event)
             logger.debug("Finished do_post_async")
 
         except Exception as e:
             logger.error(f"An error occurred during publishing: {e}", exc_info=True)
             print(f"Error: An error occurred during publishing: {e}", file=sys.stderr)
 
     async def subscribe_feed_async(
-        self, handler: Callable[[RelayManager, str, Event], None]
+        self, handler: Callable[[ClientPool, str, Event], None]
     ):
         """
         Subscribes to the feed of the client's own pubkey.
@@ -497,18 +532,18 @@ class NostrClient:
             event = Event(
                 kind=Event.KIND_TEXT_NOTE,
                 content=encrypted_json_b64,
-                pubkey=self.key_manager.keys.public_key_hex(),
+                pub_key=self.key_manager.keys.public_key_hex(),
             )
             event.created_at = int(time.time())
 
             if to_pubkey:
-                dm = EncryptedDirectMessage(
-                    cleartext_content=event.content,
-                    recipient_pubkey=to_pubkey,
-                )
-                dm.encrypt(self.key_manager.keys.private_key_hex())
-                event = dm.to_event()
+                if NIP4Encrypt is None:
+                    raise ImportError("monstr library required for NIP4 encryption")
+                nip4_encrypt = NIP4Encrypt(self.key_manager.keys)
+                event.content = nip4_encrypt.encrypt_message(event.content, to_pubkey)
+                event.kind = Event.KIND_ENCRYPT
+                logger.debug(f"Encrypted event content: {event.content}")
 
             event.sign(self.key_manager.keys.private_key_hex())
             logger.debug("Event created and signed")
@@ -572,14 +607,16 @@ class NostrClient:
             print(f"Error: Failed to decrypt and save index from Nostr: {e}", "red")
 
     async def close_client_pool_async(self):
-        """Closes the RelayManager gracefully by canceling all pending tasks and stopping the event loop."""
+        """
+        Closes the ClientPool gracefully by canceling all pending tasks and stopping the event loop.
+        """
         if self.is_shutting_down:
             logger.debug("Shutdown already in progress.")
             return
 
         try:
             self.is_shutting_down = True
-            logger.debug("Initiating RelayManager shutdown.")
+            logger.debug("Initiating ClientPool shutdown.")
 
             # Set the shutdown event
             self._shutdown_event.set()
@@ -587,18 +624,17 @@ class NostrClient:
             # Cancel all subscriptions
             for sub_id in list(self.subscriptions.keys()):
                 try:
-                    if hasattr(self.client_pool, "close_subscription_on_all_relays"):
-                        self.client_pool.close_subscription_on_all_relays(sub_id)
+                    self.client_pool.unsubscribe(sub_id)
                     del self.subscriptions[sub_id]
                     logger.debug(f"Unsubscribed from sub_id {sub_id}")
                 except Exception as e:
                     logger.warning(f"Error unsubscribing from {sub_id}: {e}")
 
             # Close all WebSocket connections
-            if hasattr(self.client_pool, "relays"):
+            if hasattr(self.client_pool, "clients"):
                 tasks = [
-                    self.safe_close_connection(relay)
-                    for relay in self.client_pool.relays.values()
+                    self.safe_close_connection(client)
+                    for client in self.client_pool.clients
                 ]
                 await asyncio.gather(*tasks, return_exceptions=True)
@@ -634,7 +670,9 @@ class NostrClient:
             self.is_shutting_down = False
 
     def close_client_pool(self):
-        """Public method to close the RelayManager gracefully."""
+        """
+        Public method to close the ClientPool gracefully.
+        """
         if self.is_shutting_down:
             logger.debug("Shutdown already in progress. Skipping redundant shutdown.")
             return
@@ -673,7 +711,7 @@ class NostrClient:
             except Exception as cleanup_error:
                 logger.error(f"Error during final cleanup: {cleanup_error}")
 
-            logger.info("RelayManager shutdown complete")
+            logger.info("ClientPool shutdown complete")
 
         except Exception as e:
             logger.error(f"Error in close_client_pool: {e}")
@@ -681,9 +719,13 @@ class NostrClient:
         finally:
             self.is_shutting_down = False
 
-    async def safe_close_connection(self, relay):
+    async def safe_close_connection(self, client):
         try:
-            relay.close()
-            logger.debug(f"Closed connection to relay: {relay.url}")
+            await client.close_connection()
+            logger.debug(f"Closed connection to relay: {client.url}")
+        except AttributeError:
+            logger.warning(
+                f"Client object has no attribute 'close_connection'. Skipping closure for {client.url}."
+            )
         except Exception as e:
-            logger.warning(f"Error closing connection to {relay.url}: {e}")
+            logger.warning(f"Error closing connection to {client.url}: {e}")

View File

@@ -11,5 +11,4 @@ bip85
 pytest>=7.0
 pytest-cov
 portalocker>=2.8
-pynostr

View File

@@ -3,8 +3,6 @@ from pathlib import Path
 from tempfile import TemporaryDirectory
 from unittest.mock import patch
 from cryptography.fernet import Fernet
-from types import SimpleNamespace
-import time
 
 sys.path.append(str(Path(__file__).resolve().parents[1]))
@@ -18,32 +16,11 @@ def test_nostr_client_uses_custom_relays():
         enc_mgr = EncryptionManager(key, Path(tmpdir))
         custom_relays = ["wss://relay1", "wss://relay2"]
 
-        with patch("nostr.client.RelayManager") as MockManager, patch(
+        with patch("nostr.client.ClientPool") as MockPool, patch(
             "nostr.client.KeyManager"
         ), patch.object(NostrClient, "initialize_client_pool"):
             with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"):
                 client = NostrClient(enc_mgr, "fp", relays=custom_relays)
 
+        MockPool.assert_called_with(custom_relays)
         assert client.relays == custom_relays
-        added = [c.args[0] for c in MockManager.return_value.add_relay.call_args_list]
-        assert added == custom_relays
-
-
-def test_wait_for_connection_timeout():
-    with TemporaryDirectory() as tmpdir:
-        key = Fernet.generate_key()
-        enc_mgr = EncryptionManager(key, Path(tmpdir))
-        with patch.object(NostrClient, "initialize_client_pool"), patch(
-            "nostr.client.RelayManager"
-        ), patch("nostr.client.KeyManager"), patch.object(
-            enc_mgr, "decrypt_parent_seed", return_value="seed"
-        ):
-            client = NostrClient(enc_mgr, "fp")
-            client.client_pool = SimpleNamespace(connection_statuses={"wss://r": False})
-            start = time.monotonic()
-            client.wait_for_connection(timeout=0.2)
-            duration = time.monotonic() - start
-            assert duration >= 0.2
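
The restored test patches nostr.client.ClientPool where NostrClient looks it up and then asserts on the mock's constructor arguments. Below is a self-contained sketch of that patching technique; the module demo_client and the classes in it are invented stand-ins, not project code.

# Sketch of patching a dependency where it is looked up (hypothetical names).
import sys
import types
from unittest.mock import patch

demo = types.ModuleType("demo_client")


class ClientPool:
    def __init__(self, relays):
        self.relays = relays


class DemoClient:
    def __init__(self, relays):
        self.relays = relays
        # Looked up through the module at call time, so patch() can intercept it.
        self.pool = demo.ClientPool(relays)


demo.ClientPool = ClientPool
demo.DemoClient = DemoClient
sys.modules["demo_client"] = demo


def test_demo_client_uses_custom_relays():
    relays = ["wss://relay1", "wss://relay2"]
    with patch("demo_client.ClientPool") as MockPool:
        client = demo.DemoClient(relays)
    MockPool.assert_called_with(relays)  # constructor received the custom relays
    assert client.relays == relays


test_demo_client_uses_custom_relays()
print("ok")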

View File

@@ -14,7 +14,7 @@ def setup_client(tmp_path):
     key = Fernet.generate_key()
     enc_mgr = EncryptionManager(key, tmp_path)
-    with patch("nostr.client.RelayManager"), patch(
+    with patch("nostr.client.ClientPool"), patch(
         "nostr.client.KeyManager"
     ), patch.object(NostrClient, "initialize_client_pool"), patch.object(
         enc_mgr, "decrypt_parent_seed", return_value="seed"
@@ -25,12 +25,12 @@ def setup_client(tmp_path):
 class FakeEvent:
     KIND_TEXT_NOTE = 1
-    KIND_ENCRYPT = 4
+    KIND_ENCRYPT = 2
 
-    def __init__(self, kind, content, pubkey):
+    def __init__(self, kind, content, pub_key):
         self.kind = kind
         self.content = content
-        self.pubkey = pubkey
+        self.pub_key = pub_key
         self.id = "id"
 
     def sign(self, _):

View File

@@ -27,7 +27,7 @@ def test_backup_and_publish_to_nostr():
     with patch(
         "nostr.client.NostrClient.publish_json_to_nostr", return_value=True
-    ) as mock_publish, patch("nostr.client.RelayManager"), patch(
+    ) as mock_publish, patch("nostr.client.ClientPool"), patch(
         "nostr.client.KeyManager"
     ), patch.object(
         NostrClient, "initialize_client_pool"