Switch to pynostr

This commit is contained in:
thePR0M3TH3AN
2025-06-30 20:55:46 -04:00
parent 87a493b845
commit f60eaa4a1e
5 changed files with 164 additions and 662 deletions

View File

@@ -1,55 +1,23 @@
import os
import sys
import logging
import traceback
import json
import time
import asyncio
import base64
import hashlib
import asyncio
import concurrent.futures
from typing import List, Optional, Callable
from pathlib import Path
try:
from monstr.client.client import ClientPool
from monstr.encrypt import Keys, NIP4Encrypt
from monstr.event.event import Event
except ImportError: # Fallback placeholders when monstr is unavailable
NIP4Encrypt = None
Event = None
class ClientPool: # minimal stub for tests when monstr is absent
def __init__(self, relays):
self.relays = relays
self.connected = True
async def run(self):
pass
def publish(self, event):
pass
def subscribe(self, handlers=None, filters=None, sub_id=None):
pass
def unsubscribe(self, sub_id):
pass
from .coincurve_keys import Keys
import threading
import json
import logging
import time
import uuid
from pathlib import Path
from typing import Callable, List, Optional
from pynostr.websocket_relay_manager import WebSocketRelayManager
from pynostr.event import Event, EventKind
from pynostr.encrypted_dm import EncryptedDirectMessage
from .key_manager import KeyManager
from password_manager.encryption import EncryptionManager
from .event_handler import EventHandler
from utils.file_lock import exclusive_lock
# Get the logger for this module
logger = logging.getLogger(__name__)
# Set the logging level to WARNING or ERROR to suppress debug logs
logger.setLevel(logging.WARNING)
DEFAULT_RELAYS = [
@@ -58,674 +26,206 @@ DEFAULT_RELAYS = [
"wss://relay.primal.net",
]
# nostr/client.py
# src/nostr/client.py
class NostrClient:
"""
NostrClient Class
Handles interactions with the Nostr network, including publishing and retrieving encrypted events.
Utilizes deterministic key derivation via BIP-85 and integrates with the monstr library for protocol operations.
"""
"""Interact with the Nostr network using pynostr."""
def __init__(
self,
encryption_manager: EncryptionManager,
fingerprint: str,
relays: Optional[List[str]] = None,
):
"""
Initializes the NostrClient with an EncryptionManager, connects to specified relays,
and sets up the KeyManager with the given fingerprint.
) -> None:
self.encryption_manager = encryption_manager
self.fingerprint = fingerprint
self.fingerprint_dir = self.encryption_manager.fingerprint_dir
self.key_manager = KeyManager(
self.encryption_manager.decrypt_parent_seed(), fingerprint
)
self.event_handler = EventHandler()
self.relays = relays if relays else DEFAULT_RELAYS
self.client_pool = None
self.subscriptions: set[str] = set()
self.initialize_client_pool()
:param encryption_manager: An instance of EncryptionManager for handling encryption/decryption.
:param fingerprint: The fingerprint to differentiate key derivations for unique identities.
:param relays: (Optional) A list of relay URLs to connect to. Defaults to predefined relays.
"""
try:
# Assign the encryption manager and fingerprint
self.encryption_manager = encryption_manager
self.fingerprint = fingerprint # Track the fingerprint
self.fingerprint_dir = (
self.encryption_manager.fingerprint_dir
) # If needed to manage directories
def initialize_client_pool(self) -> None:
"""Create the relay manager and connect to configured relays."""
self.client_pool = WebSocketRelayManager()
for relay in self.relays:
self.client_pool.add_relay(relay)
# Initialize KeyManager with the decrypted parent seed and the provided fingerprint
self.key_manager = KeyManager(
self.encryption_manager.decrypt_parent_seed(), self.fingerprint
)
async def publish_event_async(self, event: Event) -> None:
logger.debug("Publishing event %s", event.id)
self.client_pool.publish_event(event)
# Initialize event handler and client pool
self.event_handler = EventHandler()
self.relays = relays if relays else DEFAULT_RELAYS
self.client_pool = ClientPool(self.relays)
self.subscriptions = {}
# Initialize client pool and mark NostrClient as running
self.initialize_client_pool()
logger.info("NostrClient initialized successfully.")
# For shutdown handling
self.is_shutting_down = False
self._shutdown_event = asyncio.Event()
except Exception as e:
logger.error(f"Initialization failed: {e}")
logger.error(traceback.format_exc())
print(f"Error: Initialization failed: {e}", file=sys.stderr)
sys.exit(1)
def initialize_client_pool(self):
"""
Initializes the ClientPool with the specified relays in a separate thread.
"""
try:
logger.debug("Initializing ClientPool with relays.")
if ClientPool is None:
raise ImportError("monstr library is required for ClientPool")
self.client_pool = ClientPool(self.relays)
# Start the ClientPool in a separate thread
self.loop_thread = threading.Thread(target=self.run_event_loop, daemon=True)
self.loop_thread.start()
# Wait until the ClientPool is connected to all relays
self.wait_for_connection()
logger.info("ClientPool connected to all relays.")
except Exception as e:
logger.error(f"Failed to initialize ClientPool: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to initialize ClientPool: {e}", file=sys.stderr)
sys.exit(1)
def run_event_loop(self):
"""
Runs the event loop for the ClientPool in a separate thread.
"""
try:
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.loop.create_task(self.client_pool.run())
self.loop.run_forever()
except asyncio.CancelledError:
logger.debug("Event loop received cancellation.")
except Exception as e:
logger.error(f"Error running event loop in thread: {e}")
logger.error(traceback.format_exc())
print(
f"Error: Event loop in ClientPool thread encountered an issue: {e}",
file=sys.stderr,
)
finally:
if not self.loop.is_closed():
logger.debug("Closing the event loop.")
self.loop.close()
def wait_for_connection(self):
"""
Waits until the ClientPool is connected to all relays.
"""
try:
while not self.client_pool.connected:
time.sleep(0.1)
except Exception as e:
logger.error(f"Error while waiting for ClientPool to connect: {e}")
logger.error(traceback.format_exc())
async def publish_event_async(self, event: Event):
"""
Publishes a signed event to all connected relays using ClientPool.
:param event: The signed Event object to publish.
"""
try:
logger.debug(f"Publishing event: {event.serialize()}")
self.client_pool.publish(event)
logger.info(f"Event published with ID: {event.id}")
logger.debug(f"Finished publishing event: {event.id}")
except Exception as e:
logger.error(f"Failed to publish event: {e}")
logger.error(traceback.format_exc())
def publish_event(self, event: Event):
"""
Synchronous wrapper for publishing an event.
:param event: The signed Event object to publish.
"""
try:
logger.debug(f"Submitting publish_event_async for event ID: {event.id}")
future = asyncio.run_coroutine_threadsafe(
self.publish_event_async(event), self.loop
)
# Wait for the future to complete
future.result(timeout=5) # Adjust the timeout as needed
except Exception as e:
logger.error(f"Error in publish_event: {e}")
print(f"Error: Failed to publish event: {e}", file=sys.stderr)
def publish_event(self, event: Event) -> None:
self.client_pool.publish_event(event)
async def subscribe_async(
self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None]
):
"""
Subscribes to events based on the provided filters using ClientPool.
self,
filters: List[dict],
handler: Callable[[WebSocketRelayManager, str, Event], None],
timeout: float = 2.0,
) -> None:
sub_id = str(uuid.uuid4())
from pynostr.filters import FiltersList
:param filters: A list of filter dictionaries.
:param handler: A callback function to handle incoming events.
"""
filter_list = FiltersList.from_json_array(filters)
self.client_pool.add_subscription_on_all_relays(sub_id, filter_list)
self.subscriptions.add(sub_id)
end = asyncio.get_event_loop().time() + timeout
try:
sub_id = str(uuid.uuid4())
self.client_pool.subscribe(handlers=handler, filters=filters, sub_id=sub_id)
logger.info(f"Subscribed to events with subscription ID: {sub_id}")
self.subscriptions[sub_id] = True
except Exception as e:
logger.error(f"Failed to subscribe: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to subscribe: {e}", file=sys.stderr)
while asyncio.get_event_loop().time() < end:
while self.client_pool.message_pool.has_events():
msg = self.client_pool.message_pool.get_event()
if msg.subscription_id == sub_id:
handler(self.client_pool, sub_id, msg.event)
await asyncio.sleep(0.1)
finally:
self.client_pool.close_subscription_on_all_relays(sub_id)
self.subscriptions.discard(sub_id)
def subscribe(
self, filters: List[dict], handler: Callable[[ClientPool, str, Event], None]
):
"""
Synchronous wrapper for subscribing to events.
:param filters: A list of filter dictionaries.
:param handler: A callback function to handle incoming events.
"""
try:
asyncio.run_coroutine_threadsafe(
self.subscribe_async(filters, handler), self.loop
)
except Exception as e:
logger.error(f"Error in subscribe: {e}")
print(f"Error: Failed to subscribe: {e}", file=sys.stderr)
self,
filters: List[dict],
handler: Callable[[WebSocketRelayManager, str, Event], None],
timeout: float = 2.0,
) -> None:
asyncio.run(self.subscribe_async(filters, handler, timeout))
async def retrieve_json_from_nostr_async(self) -> Optional[str]:
"""
Retrieves the latest encrypted JSON event from Nostr.
filters = [
{
"authors": [self.key_manager.keys.public_key_hex()],
"kinds": [EventKind.TEXT_NOTE, EventKind.ENCRYPTED_DIRECT_MESSAGE],
"limit": 1,
}
]
events: list[Event] = []
:return: The encrypted JSON data as a Base64-encoded string, or None if retrieval fails.
"""
try:
filters = [
{
"authors": [self.key_manager.keys.public_key_hex()],
"kinds": [Event.KIND_TEXT_NOTE, Event.KIND_ENCRYPT],
"limit": 1,
}
]
async def handler(_client, _sid, evt: Event):
events.append(evt)
events = []
await self.subscribe_async(filters, handler)
def my_handler(the_client, sub_id, evt: Event):
logger.debug(f"Received event: {evt.serialize()}")
events.append(evt)
await self.subscribe_async(filters=filters, handler=my_handler)
await asyncio.sleep(2) # Adjust the sleep time as needed
# Unsubscribe from all subscriptions
for sub_id in list(self.subscriptions.keys()):
self.client_pool.unsubscribe(sub_id)
del self.subscriptions[sub_id]
logger.debug(f"Unsubscribed from sub_id {sub_id}")
if events:
event = events[0]
content_base64 = event.content
if event.kind == Event.KIND_ENCRYPT:
if NIP4Encrypt is None:
raise ImportError("monstr library required for NIP4 encryption")
nip4_encrypt = NIP4Encrypt(self.key_manager.keys)
content_base64 = nip4_encrypt.decrypt_message(
event.content, event.pub_key
)
# Return the Base64-encoded content as a string
logger.debug("Encrypted JSON data retrieved successfully.")
return content_base64
else:
logger.warning("No events found matching the filters.")
print("No events found matching the filters.", file=sys.stderr)
return None
except Exception as e:
logger.error(f"Failed to retrieve JSON from Nostr: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to retrieve JSON from Nostr: {e}", file=sys.stderr)
if not events:
return None
def retrieve_json_from_nostr(self) -> Optional[bytes]:
"""
Public method to retrieve encrypted JSON from Nostr.
:return: The encrypted JSON data as bytes, or None if retrieval fails.
"""
try:
future = asyncio.run_coroutine_threadsafe(
self.retrieve_json_from_nostr_async(), self.loop
event = events[0]
content_base64 = event.content
if event.kind == EventKind.ENCRYPTED_DIRECT_MESSAGE:
dm = EncryptedDirectMessage.from_event(event)
dm.decrypt(
self.key_manager.keys.private_key_hex(), public_key_hex=dm.pubkey
)
return future.result(timeout=10)
except concurrent.futures.TimeoutError:
logger.error("Timeout occurred while retrieving JSON from Nostr.")
print(
"Error: Timeout occurred while retrieving JSON from Nostr.",
file=sys.stderr,
)
return None
except Exception as e:
logger.error(f"Error in retrieve_json_from_nostr: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to retrieve JSON from Nostr: {e}", "red")
return None
content_base64 = dm.cleartext_content
return content_base64
async def do_post_async(self, text: str):
"""
Creates and publishes a text note event.
def retrieve_json_from_nostr(self) -> Optional[str]:
return asyncio.run(self.retrieve_json_from_nostr_async())
:param text: The content of the text note.
"""
try:
event = Event(
kind=Event.KIND_TEXT_NOTE,
content=text,
pub_key=self.key_manager.keys.public_key_hex(),
)
event.created_at = int(time.time())
event.sign(self.key_manager.keys.private_key_hex())
logger.debug(f"Event data: {event.serialize()}")
await self.publish_event_async(event)
logger.debug("Finished do_post_async")
except Exception as e:
logger.error(f"An error occurred during publishing: {e}", exc_info=True)
print(f"Error: An error occurred during publishing: {e}", file=sys.stderr)
async def do_post_async(self, text: str) -> None:
event = Event(kind=EventKind.TEXT_NOTE, content=text)
event.pubkey = self.key_manager.keys.public_key_hex()
event.created_at = int(time.time())
event.sign(self.key_manager.keys.private_key_hex())
await self.publish_event_async(event)
async def subscribe_feed_async(
self, handler: Callable[[ClientPool, str, Event], None]
):
"""
Subscribes to the feed of the client's own pubkey.
self, handler: Callable[[WebSocketRelayManager, str, Event], None]
) -> None:
filters = [
{
"authors": [self.key_manager.keys.public_key_hex()],
"kinds": [EventKind.TEXT_NOTE, EventKind.ENCRYPTED_DIRECT_MESSAGE],
"limit": 100,
}
]
await self.subscribe_async(filters, handler)
:param handler: A callback function to handle incoming events.
"""
try:
filters = [
{
"authors": [self.key_manager.keys.public_key_hex()],
"kinds": [Event.KIND_TEXT_NOTE, Event.KIND_ENCRYPT],
"limit": 100,
}
]
async def publish_and_subscribe_async(self, text: str) -> None:
await asyncio.gather(
self.do_post_async(text),
self.subscribe_feed_async(self.event_handler.handle_new_event),
)
await self.subscribe_async(filters=filters, handler=handler)
logger.info("Subscribed to your feed.")
# Removed the infinite loop to prevent blocking
except Exception as e:
logger.error(f"An error occurred during subscription: {e}", exc_info=True)
print(f"Error: An error occurred during subscription: {e}", file=sys.stderr)
async def publish_and_subscribe_async(self, text: str):
"""
Publishes a text note and subscribes to the feed concurrently.
:param text: The content of the text note to publish.
"""
try:
await asyncio.gather(
self.do_post_async(text),
self.subscribe_feed_async(self.event_handler.handle_new_event),
)
except Exception as e:
logger.error(
f"An error occurred in publish_and_subscribe_async: {e}", exc_info=True
)
print(
f"Error: An error occurred in publish and subscribe: {e}",
file=sys.stderr,
)
def publish_and_subscribe(self, text: str):
"""
Public method to publish a text note and subscribe to the feed.
:param text: The content of the text note to publish.
"""
try:
asyncio.run_coroutine_threadsafe(
self.publish_and_subscribe_async(text), self.loop
)
except Exception as e:
logger.error(f"Error in publish_and_subscribe: {e}", exc_info=True)
print(f"Error: Failed to publish and subscribe: {e}", file=sys.stderr)
def publish_and_subscribe(self, text: str) -> None:
asyncio.run(self.publish_and_subscribe_async(text))
def decrypt_and_save_index_from_nostr(self, encrypted_data: bytes) -> None:
"""
Decrypts the encrypted data retrieved from Nostr and updates the local index file.
:param encrypted_data: The encrypted data retrieved from Nostr.
"""
try:
decrypted_data = self.encryption_manager.decrypt_data(encrypted_data)
data = json.loads(decrypted_data.decode("utf-8"))
self.save_json_data(data)
self.update_checksum()
logger.info("Index file updated from Nostr successfully.")
print(colored("Index file updated from Nostr successfully.", "green"))
except Exception as e:
logger.error(f"Failed to decrypt and save data from Nostr: {e}")
logger.error(traceback.format_exc())
print(
colored(
f"Error: Failed to decrypt and save data from Nostr: {e}", "red"
)
)
decrypted_data = self.encryption_manager.decrypt_data(encrypted_data)
data = json.loads(decrypted_data.decode("utf-8"))
self.save_json_data(data)
self.update_checksum()
def save_json_data(self, data: dict) -> None:
"""
Saves the JSON data to the index file in an encrypted format.
:param data: The JSON data to save.
"""
try:
encrypted_data = self.encryption_manager.encrypt_data(
json.dumps(data).encode("utf-8")
)
index_file_path = self.fingerprint_dir / "seedpass_passwords_db.json.enc"
with exclusive_lock(index_file_path):
with open(index_file_path, "wb") as f:
f.write(encrypted_data)
logger.debug(f"Encrypted data saved to {index_file_path}.")
print(colored(f"Encrypted data saved to '{index_file_path}'.", "green"))
except Exception as e:
logger.error(f"Failed to save encrypted data: {e}")
logger.error(traceback.format_exc())
print(colored(f"Error: Failed to save encrypted data: {e}", "red"))
raise
encrypted_data = self.encryption_manager.encrypt_data(
json.dumps(data).encode("utf-8")
)
index_file_path = self.fingerprint_dir / "seedpass_passwords_db.json.enc"
with exclusive_lock(index_file_path):
with open(index_file_path, "wb") as f:
f.write(encrypted_data)
def update_checksum(self) -> None:
"""
Updates the checksum file for the password database.
"""
try:
index_file_path = self.fingerprint_dir / "seedpass_passwords_db.json.enc"
decrypted_data = self.decrypt_data_from_file(index_file_path)
content = decrypted_data.decode("utf-8")
logger.debug("Calculating checksum of the updated file content.")
checksum = hashlib.sha256(content.encode("utf-8")).hexdigest()
logger.debug(f"New checksum: {checksum}")
checksum_file = self.fingerprint_dir / "seedpass_passwords_db_checksum.txt"
with exclusive_lock(checksum_file):
with open(checksum_file, "w") as f:
f.write(checksum)
os.chmod(checksum_file, 0o600)
logger.debug(
f"Checksum for '{index_file_path}' updated and written to '{checksum_file}'."
)
print(colored(f"Checksum for '{index_file_path}' updated.", "green"))
except Exception as e:
logger.error(f"Failed to update checksum: {e}")
logger.error(traceback.format_exc())
print(colored(f"Error: Failed to update checksum: {e}", "red"))
index_file_path = self.fingerprint_dir / "seedpass_passwords_db.json.enc"
decrypted_data = self.decrypt_data_from_file(index_file_path)
content = decrypted_data.decode("utf-8")
checksum = hashlib.sha256(content.encode("utf-8")).hexdigest()
checksum_file = self.fingerprint_dir / "seedpass_passwords_db_checksum.txt"
with exclusive_lock(checksum_file):
with open(checksum_file, "w") as f:
f.write(checksum)
checksum_file.chmod(0o600)
def decrypt_data_from_file(self, file_path: Path) -> bytes:
"""
Decrypts data directly from a file.
:param file_path: Path to the encrypted file as a Path object.
:return: Decrypted data as bytes.
"""
try:
with exclusive_lock(file_path):
with open(file_path, "rb") as f:
encrypted_data = f.read()
decrypted_data = self.encryption_manager.decrypt_data(encrypted_data)
logger.debug(f"Data decrypted from file '{file_path}'.")
return decrypted_data
except Exception as e:
logger.error(f"Failed to decrypt data from file '{file_path}': {e}")
logger.error(traceback.format_exc())
print(
colored(
f"Error: Failed to decrypt data from file '{file_path}': {e}", "red"
)
)
raise
with exclusive_lock(file_path):
with open(file_path, "rb") as f:
encrypted_data = f.read()
return self.encryption_manager.decrypt_data(encrypted_data)
def publish_json_to_nostr(
self, encrypted_json: bytes, to_pubkey: str | None = None
) -> bool:
"""Post encrypted JSON to Nostr.
Parameters
----------
encrypted_json:
The encrypted JSON data to send.
to_pubkey:
Optional recipient public key. If provided the message will be NIP-4
encrypted for that key.
Returns
-------
bool
``True`` when the event is successfully published, ``False`` on
failure.
"""
try:
encrypted_json_b64 = base64.b64encode(encrypted_json).decode("utf-8")
logger.debug(f"Encrypted JSON (base64): {encrypted_json_b64}")
event = Event(
kind=Event.KIND_TEXT_NOTE,
content=encrypted_json_b64,
pub_key=self.key_manager.keys.public_key_hex(),
)
event.created_at = int(time.time())
content = base64.b64encode(encrypted_json).decode("utf-8")
if to_pubkey:
if NIP4Encrypt is None:
raise ImportError("monstr library required for NIP4 encryption")
nip4_encrypt = NIP4Encrypt(self.key_manager.keys)
event.content = nip4_encrypt.encrypt_message(event.content, to_pubkey)
event.kind = Event.KIND_ENCRYPT
logger.debug(f"Encrypted event content: {event.content}")
dm = EncryptedDirectMessage()
dm.encrypt(
private_key_hex=self.key_manager.keys.private_key_hex(),
cleartext_content=content,
recipient_pubkey=to_pubkey,
)
event = dm.to_event()
else:
event = Event(kind=EventKind.TEXT_NOTE, content=content)
event.pubkey = self.key_manager.keys.public_key_hex()
event.created_at = int(time.time())
event.sign(self.key_manager.keys.private_key_hex())
logger.debug("Event created and signed")
self.publish_event(event)
logger.debug("Event published")
return True
except Exception as e:
logger.error(f"Failed to publish JSON to Nostr: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to publish JSON to Nostr: {e}", file=sys.stderr)
except Exception as e: # pragma: no cover - defensive
logger.error("Failed to publish JSON to Nostr: %s", e)
return False
def retrieve_json_from_nostr_sync(self) -> Optional[bytes]:
"""
Retrieves encrypted data from Nostr and Base64-decodes it.
Returns:
Optional[bytes]: The encrypted data as bytes if successful, None otherwise.
"""
try:
future = asyncio.run_coroutine_threadsafe(
self.retrieve_json_from_nostr_async(), self.loop
)
content_base64 = future.result(timeout=10)
if not content_base64:
logger.debug("No data retrieved from Nostr.")
return None
# Base64-decode the content
encrypted_data = base64.urlsafe_b64decode(content_base64.encode("utf-8"))
logger.debug(
"Encrypted data retrieved and Base64-decoded successfully from Nostr."
)
return encrypted_data
except concurrent.futures.TimeoutError:
logger.error("Timeout occurred while retrieving JSON from Nostr.")
print(
"Error: Timeout occurred while retrieving JSON from Nostr.",
file=sys.stderr,
)
return None
except Exception as e:
logger.error(f"Error in retrieve_json_from_nostr: {e}")
logger.error(traceback.format_exc())
print(f"Error: Failed to retrieve JSON from Nostr: {e}", "red")
return None
content = self.retrieve_json_from_nostr()
if content:
return base64.urlsafe_b64decode(content.encode("utf-8"))
return None
def decrypt_and_save_index_from_nostr_public(self, encrypted_data: bytes) -> None:
"""
Public method to decrypt and save data from Nostr.
self.decrypt_and_save_index_from_nostr(encrypted_data)
:param encrypted_data: The encrypted data retrieved from Nostr.
"""
try:
self.decrypt_and_save_index_from_nostr(encrypted_data)
except Exception as e:
logger.error(f"Failed to decrypt and save index from Nostr: {e}")
print(f"Error: Failed to decrypt and save index from Nostr: {e}", "red")
async def close_client_pool_async(self) -> None:
self.client_pool.close_all_relay_connections()
async def close_client_pool_async(self):
"""
Closes the ClientPool gracefully by canceling all pending tasks and stopping the event loop.
"""
if self.is_shutting_down:
logger.debug("Shutdown already in progress.")
return
def close_client_pool(self) -> None:
self.client_pool.close_all_relay_connections()
try:
self.is_shutting_down = True
logger.debug("Initiating ClientPool shutdown.")
# Set the shutdown event
self._shutdown_event.set()
# Cancel all subscriptions
for sub_id in list(self.subscriptions.keys()):
try:
self.client_pool.unsubscribe(sub_id)
del self.subscriptions[sub_id]
logger.debug(f"Unsubscribed from sub_id {sub_id}")
except Exception as e:
logger.warning(f"Error unsubscribing from {sub_id}: {e}")
# Close all WebSocket connections
if hasattr(self.client_pool, "clients"):
tasks = [
self.safe_close_connection(client)
for client in self.client_pool.clients
]
await asyncio.gather(*tasks, return_exceptions=True)
# Gather and cancel all tasks
current_task = asyncio.current_task()
tasks = [
task
for task in asyncio.all_tasks(loop=self.loop)
if task != current_task and not task.done()
]
if tasks:
logger.debug(f"Cancelling {len(tasks)} pending tasks.")
for task in tasks:
task.cancel()
# Wait for all tasks to be cancelled with a timeout
try:
await asyncio.wait_for(
asyncio.gather(*tasks, return_exceptions=True), timeout=5
)
except asyncio.TimeoutError:
logger.warning("Timeout waiting for tasks to cancel")
logger.debug("Stopping the event loop.")
self.loop.stop()
logger.info("Event loop stopped successfully.")
except Exception as e:
logger.error(f"Error during async shutdown: {e}")
logger.error(traceback.format_exc())
finally:
self.is_shutting_down = False
def close_client_pool(self):
"""
Public method to close the ClientPool gracefully.
"""
if self.is_shutting_down:
logger.debug("Shutdown already in progress. Skipping redundant shutdown.")
return
try:
# Schedule the coroutine to close the client pool
future = asyncio.run_coroutine_threadsafe(
self.close_client_pool_async(), self.loop
)
# Wait for the coroutine to finish with a timeout
try:
future.result(timeout=10)
except concurrent.futures.TimeoutError:
logger.warning("Initial shutdown attempt timed out, forcing cleanup...")
# Additional cleanup regardless of timeout
try:
self.loop.call_soon_threadsafe(self.loop.stop)
# Give a short grace period for the loop to stop
time.sleep(0.5)
if self.loop.is_running():
logger.warning("Loop still running after stop, closing forcefully")
self.loop.call_soon_threadsafe(self.loop.close)
# Wait for the thread with a reasonable timeout
if self.loop_thread.is_alive():
self.loop_thread.join(timeout=5)
if self.loop_thread.is_alive():
logger.warning(
"Thread still alive after join, may need to be force-killed"
)
except Exception as cleanup_error:
logger.error(f"Error during final cleanup: {cleanup_error}")
logger.info("ClientPool shutdown complete")
except Exception as e:
logger.error(f"Error in close_client_pool: {e}")
logger.error(traceback.format_exc())
finally:
self.is_shutting_down = False
async def safe_close_connection(self, client):
async def safe_close_connection(self, client): # pragma: no cover - compatibility
try:
await client.close_connection()
logger.debug(f"Closed connection to relay: {client.url}")
except AttributeError:
logger.warning(
f"Client object has no attribute 'close_connection'. Skipping closure for {client.url}."
)
except Exception as e:
logger.warning(f"Error closing connection to {client.url}: {e}")
except Exception:
pass

View File

@@ -11,4 +11,6 @@ bip85
pytest>=7.0
pytest-cov
portalocker>=2.8
pynostr>=0.6.2
websocket-client

View File

@@ -16,11 +16,11 @@ def test_nostr_client_uses_custom_relays():
enc_mgr = EncryptionManager(key, Path(tmpdir))
custom_relays = ["wss://relay1", "wss://relay2"]
with patch("nostr.client.ClientPool") as MockPool, patch(
with patch("nostr.client.WebSocketRelayManager") as MockPool, patch(
"nostr.client.KeyManager"
), patch.object(NostrClient, "initialize_client_pool"):
):
with patch.object(enc_mgr, "decrypt_parent_seed", return_value="seed"):
client = NostrClient(enc_mgr, "fp", relays=custom_relays)
MockPool.assert_called_with(custom_relays)
MockPool.assert_called_with()
assert client.relays == custom_relays

View File

@@ -14,7 +14,7 @@ def setup_client(tmp_path):
key = Fernet.generate_key()
enc_mgr = EncryptionManager(key, tmp_path)
with patch("nostr.client.ClientPool"), patch(
with patch("nostr.client.WebSocketRelayManager"), patch(
"nostr.client.KeyManager"
), patch.object(NostrClient, "initialize_client_pool"), patch.object(
enc_mgr, "decrypt_parent_seed", return_value="seed"
@@ -27,10 +27,10 @@ class FakeEvent:
KIND_TEXT_NOTE = 1
KIND_ENCRYPT = 2
def __init__(self, kind, content, pub_key):
def __init__(self, kind, content, pub_key=None):
self.kind = kind
self.content = content
self.pub_key = pub_key
self.pubkey = pub_key
self.id = "id"
def sign(self, _):

View File

@@ -27,7 +27,7 @@ def test_backup_and_publish_to_nostr():
with patch(
"nostr.client.NostrClient.publish_json_to_nostr", return_value=True
) as mock_publish, patch("nostr.client.ClientPool"), patch(
) as mock_publish, patch("nostr.client.WebSocketRelayManager"), patch(
"nostr.client.KeyManager"
), patch.object(
NostrClient, "initialize_client_pool"