diff --git a/src/seedpass/core/api.py b/src/seedpass/core/api.py index 28046cd..3198241 100644 --- a/src/seedpass/core/api.py +++ b/src/seedpass/core/api.py @@ -148,7 +148,9 @@ class VaultService: """Restore a profile from ``data`` and sync.""" with self._lock: - decrypted = self._manager.vault.encryption_manager.decrypt_data(data) + decrypted = self._manager.vault.encryption_manager.decrypt_data( + data, context="profile" + ) index = json.loads(decrypted.decode("utf-8")) self._manager.vault.save_index(index) self._manager.sync_vault() diff --git a/src/seedpass/core/encryption.py b/src/seedpass/core/encryption.py index f669ff7..93ba557 100644 --- a/src/seedpass/core/encryption.py +++ b/src/seedpass/core/encryption.py @@ -92,11 +92,22 @@ class EncryptionManager: logger.error(f"Failed to encrypt data: {e}", exc_info=True) raise - def decrypt_data(self, encrypted_data: bytes) -> bytes: - """ - (3) The core migration logic. Tries the new format first, then falls back - to the old one. This is the ONLY place decryption logic should live. + def decrypt_data( + self, encrypted_data: bytes, context: Optional[str] = None + ) -> bytes: + """Decrypt ``encrypted_data`` handling legacy fallbacks. + + Parameters + ---------- + encrypted_data: + The bytes to decrypt. + context: + Optional string describing what is being decrypted ("seed", "index", etc.) + for clearer error messages. """ + + ctx = f" {context}" if context else "" + try: # Try the new V2 format first if encrypted_data.startswith(b"V2:"): @@ -118,7 +129,9 @@ class EncryptionManager: ) return result except InvalidToken: - raise InvalidToken("AES-GCM decryption failed.") from e + msg = f"Failed to decrypt{ctx}: invalid key or corrupt file" + logger.error(msg) + raise InvalidToken(msg) from e # If it's not V2, it must be the legacy Fernet format else: @@ -129,19 +142,20 @@ class EncryptionManager: logger.error( "Legacy Fernet decryption failed. Vault may be corrupt or key is incorrect." ) - raise InvalidToken( - "Could not decrypt data with any available method." - ) from e + raise e except (InvalidToken, InvalidTag) as e: + if encrypted_data.startswith(b"V2:"): + # Already determined not to be legacy; re-raise + raise if isinstance(e, InvalidToken) and str(e) == "AES-GCM payload too short": raise if not self._legacy_migrate_flag: raise - logger.debug(f"Could not decrypt data: {e}") + logger.debug(f"Could not decrypt data{ctx}: {e}") print( colored( - "Failed to decrypt with current key. This may be a legacy index.", + f"Failed to decrypt{ctx} with current key. This may be a legacy index.", "red", ) ) @@ -172,7 +186,7 @@ class EncryptionManager: ) legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir) legacy_mgr._legacy_migrate_flag = False - result = legacy_mgr.decrypt_data(encrypted_data) + result = legacy_mgr.decrypt_data(encrypted_data, context=context) try: # record iteration count for future runs from .vault import Vault from .config_manager import ConfigManager @@ -194,7 +208,7 @@ class EncryptionManager: last_exc = e2 logger.error(f"Failed legacy decryption attempt: {last_exc}", exc_info=True) raise InvalidToken( - "Could not decrypt data with any available method." + f"Could not decrypt{ctx} with any available method." ) from e # --- All functions below this point now use the smart `decrypt_data` method --- @@ -241,7 +255,7 @@ class EncryptionManager: encrypted_data = fh.read() is_legacy = not encrypted_data.startswith(b"V2:") - decrypted_data = self.decrypt_data(encrypted_data) + decrypted_data = self.decrypt_data(encrypted_data, context="seed") if is_legacy: logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.") @@ -266,7 +280,7 @@ class EncryptionManager: with exclusive_lock(file_path) as fh: fh.seek(0) encrypted_data = fh.read() - return self.decrypt_data(encrypted_data) + return self.decrypt_data(encrypted_data, context=str(relative_path)) def save_json_data(self, data: dict, relative_path: Optional[Path] = None) -> None: if relative_path is None: @@ -297,7 +311,9 @@ class EncryptionManager: self.last_migration_performed = False try: - decrypted_data = self.decrypt_data(encrypted_data) + decrypted_data = self.decrypt_data( + encrypted_data, context=str(relative_path) + ) if USE_ORJSON: data = json_lib.loads(decrypted_data) else: @@ -376,7 +392,9 @@ class EncryptionManager: return data try: - decrypted_data = self.decrypt_data(encrypted_data) + decrypted_data = self.decrypt_data( + encrypted_data, context=str(relative_path) + ) data = _process(decrypted_data) self.save_json_data(data, relative_path) # This always saves in V2 format self.update_checksum(relative_path) @@ -391,7 +409,9 @@ class EncryptionManager: ) legacy_key = _derive_legacy_key_from_password(password) legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir) - decrypted_data = legacy_mgr.decrypt_data(encrypted_data) + decrypted_data = legacy_mgr.decrypt_data( + encrypted_data, context=str(relative_path) + ) data = _process(decrypted_data) self.save_json_data(data, relative_path) self.update_checksum(relative_path) diff --git a/src/seedpass/core/portable_backup.py b/src/seedpass/core/portable_backup.py index a76879b..a8e4af5 100644 --- a/src/seedpass/core/portable_backup.py +++ b/src/seedpass/core/portable_backup.py @@ -112,7 +112,7 @@ def import_backup( raw = Path(path).read_bytes() if path.suffix.endswith(".enc"): - raw = vault.encryption_manager.decrypt_data(raw) + raw = vault.encryption_manager.decrypt_data(raw, context=str(path)) wrapper = json.loads(raw.decode("utf-8")) if wrapper.get("format_version") != FORMAT_VERSION: @@ -129,7 +129,7 @@ def import_backup( ) key = _derive_export_key(seed) enc_mgr = EncryptionManager(key, vault.fingerprint_dir) - index_bytes = enc_mgr.decrypt_data(payload) + index_bytes = enc_mgr.decrypt_data(payload, context="backup payload") index = json.loads(index_bytes.decode("utf-8")) checksum = json_checksum(index) diff --git a/src/tests/test_decrypt_data_legacy_fallback.py b/src/tests/test_decrypt_data_legacy_fallback.py deleted file mode 100644 index 6172eca..0000000 --- a/src/tests/test_decrypt_data_legacy_fallback.py +++ /dev/null @@ -1,36 +0,0 @@ -import base64 -import hashlib -import unicodedata - -from helpers import TEST_PASSWORD -import seedpass.core.encryption as enc_module -from seedpass.core.encryption import EncryptionManager -from utils.key_derivation import derive_key_from_password - - -def test_decrypt_data_password_fallback(tmp_path, monkeypatch): - calls: list[int] = [] - - def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes: - calls.append(iterations) - normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8") - key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32) - return base64.urlsafe_b64encode(key) - - monkeypatch.setattr( - enc_module, "_derive_legacy_key_from_password", _fast_legacy_key - ) - monkeypatch.setattr( - enc_module, "prompt_existing_password", lambda *_a, **_k: TEST_PASSWORD - ) - monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1") - - legacy_key = _fast_legacy_key(TEST_PASSWORD, iterations=50_000) - legacy_mgr = EncryptionManager(legacy_key, tmp_path) - payload = legacy_mgr.encrypt_data(b"secret") - - new_key = derive_key_from_password(TEST_PASSWORD, "fp") - new_mgr = EncryptionManager(new_key, tmp_path) - - assert new_mgr.decrypt_data(payload) == b"secret" - assert calls == [50_000, 50_000] diff --git a/src/tests/test_decrypt_messages.py b/src/tests/test_decrypt_messages.py new file mode 100644 index 0000000..ff0f955 --- /dev/null +++ b/src/tests/test_decrypt_messages.py @@ -0,0 +1,62 @@ +import base64 +import hashlib +import unicodedata + +import pytest +from cryptography.fernet import InvalidToken + +from helpers import TEST_PASSWORD, TEST_SEED +from seedpass.core.encryption import EncryptionManager +from utils.key_derivation import derive_index_key + + +def test_wrong_password_message(tmp_path): + key = derive_index_key(TEST_SEED) + mgr = EncryptionManager(key, tmp_path) + payload = mgr.encrypt_data(b"secret") + + wrong_key = bytearray(key) + wrong_key[0] ^= 1 + wrong_mgr = EncryptionManager(bytes(wrong_key), tmp_path) + + with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc: + wrong_mgr.decrypt_data(payload, context="index") + assert "index" in str(exc.value) + + +def test_legacy_file_requires_migration_message(tmp_path, monkeypatch, capsys): + def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes: + normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8") + key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32) + return base64.urlsafe_b64encode(key) + + monkeypatch.setattr( + "seedpass.core.encryption._derive_legacy_key_from_password", _fast_legacy_key + ) + monkeypatch.setattr( + "seedpass.core.encryption.prompt_existing_password", + lambda *_a, **_k: TEST_PASSWORD, + ) + monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1") + + legacy_key = _fast_legacy_key(TEST_PASSWORD) + legacy_mgr = EncryptionManager(legacy_key, tmp_path) + token = legacy_mgr.fernet.encrypt(b"secret") + + new_mgr = EncryptionManager(derive_index_key(TEST_SEED), tmp_path) + assert new_mgr.decrypt_data(token, context="index") == b"secret" + + out = capsys.readouterr().out + assert "Failed to decrypt index" in out + assert "legacy index" in out + + +def test_corrupted_data_message(tmp_path): + key = derive_index_key(TEST_SEED) + mgr = EncryptionManager(key, tmp_path) + payload = bytearray(mgr.encrypt_data(b"secret")) + payload[-1] ^= 0xFF + + with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc: + mgr.decrypt_data(bytes(payload), context="index") + assert "index" in str(exc.value)