diff --git a/src/nostr/client.py b/src/nostr/client.py index a5c9d52..0ac2f6b 100644 --- a/src/nostr/client.py +++ b/src/nostr/client.py @@ -510,13 +510,16 @@ class NostrClient: self.current_manifest_id = ident return manifest, chunks - async def fetch_latest_snapshot(self) -> Tuple[Manifest, list[bytes]] | None: - """Retrieve the latest manifest and all snapshot chunks.""" - if self.offline_mode or not self.relays: - return None - await self._connect_async() + async def _fetch_manifest_with_keys( + self, keys_obj: Keys + ) -> tuple[Manifest, list[bytes]] | None: + """Attempt to retrieve the manifest and chunks using ``keys_obj``. - self.last_error = None + ``self.keys`` is updated to ``keys_obj`` so that subsequent chunk and + delta downloads use the same public key that succeeded. + """ + + self.keys = keys_obj pubkey = self.keys.public_key() identifiers = [ f"{MANIFEST_ID_PREFIX}{self.fingerprint}", @@ -560,6 +563,36 @@ class NostrClient: # manifest was found but chunks missing; do not try other identifiers return None + return None + + async def fetch_latest_snapshot(self) -> Tuple[Manifest, list[bytes]] | None: + """Retrieve the latest manifest and all snapshot chunks.""" + if self.offline_mode or not self.relays: + return None + await self._connect_async() + + self.last_error = None + + try: + primary_keys = Keys.parse(self.key_manager.keys.private_key_hex()) + except Exception: + primary_keys = self.keys + + result = await self._fetch_manifest_with_keys(primary_keys) + if result is not None: + return result + + try: + legacy_keys = self.key_manager.generate_legacy_nostr_keys() + legacy_sdk_keys = Keys.parse(legacy_keys.private_key_hex()) + except Exception as e: + self.last_error = str(e) + return None + + result = await self._fetch_manifest_with_keys(legacy_sdk_keys) + if result is not None: + return result + if self.last_error is None: self.last_error = "Snapshot not found on relays" diff --git a/src/tests/test_nostr_legacy_key_fallback.py b/src/tests/test_nostr_legacy_key_fallback.py new file mode 100644 index 0000000..a7cebb5 --- /dev/null +++ b/src/tests/test_nostr_legacy_key_fallback.py @@ -0,0 +1,86 @@ +import asyncio +import base64 +import hashlib +import json + +from helpers import DummyEvent, DummyFilter, dummy_nostr_client +from nostr.backup_models import KIND_MANIFEST, KIND_SNAPSHOT_CHUNK +from nostr.client import MANIFEST_ID_PREFIX +from nostr_sdk import Keys + + +def test_fetch_snapshot_legacy_key_fallback(dummy_nostr_client, monkeypatch): + client, relay = dummy_nostr_client + + # Track legacy key generation + called = {"legacy": False} + + class LegacyKeys: + def private_key_hex(self): + return "3" * 64 + + def public_key_hex(self): + return "4" * 64 + + def fake_generate(): + called["legacy"] = True + return LegacyKeys() + + monkeypatch.setattr( + client.key_manager, "generate_legacy_nostr_keys", fake_generate, raising=False + ) + + expected_pubkey = Keys.parse("3" * 64).public_key() + + class RecordingFilter(DummyFilter): + def author(self, pk): + self.author_pk = pk + return self + + monkeypatch.setattr("nostr.client.Filter", RecordingFilter) + + chunk_bytes = b"chunkdata" + chunk_hash = hashlib.sha256(chunk_bytes).hexdigest() + manifest_json = json.dumps( + { + "ver": 1, + "algo": "gzip", + "chunks": [ + { + "id": "seedpass-chunk-0000", + "size": len(chunk_bytes), + "hash": chunk_hash, + "event_id": None, + } + ], + } + ) + manifest_event = DummyEvent( + KIND_MANIFEST, manifest_json, tags=[f"{MANIFEST_ID_PREFIX}fp"] + ) + chunk_event = DummyEvent( + KIND_SNAPSHOT_CHUNK, + base64.b64encode(chunk_bytes).decode("utf-8"), + tags=["seedpass-chunk-0000"], + ) + + call = {"count": 0, "authors": []} + + async def fake_fetch_events(f, _timeout): + call["count"] += 1 + call["authors"].append(getattr(f, "author_pk", None)) + if call["count"] <= 2: + return type("R", (), {"to_vec": lambda self: []})() + elif call["count"] == 3: + return type("R", (), {"to_vec": lambda self: [manifest_event]})() + else: + return type("R", (), {"to_vec": lambda self: [chunk_event]})() + + monkeypatch.setattr(relay, "fetch_events", fake_fetch_events) + + result = asyncio.run(client.fetch_latest_snapshot()) + assert called["legacy"] + assert result is not None + manifest, chunks = result + assert b"".join(chunks) == chunk_bytes + assert call["authors"][-1].to_hex() == expected_pubkey.to_hex()