Refactor decrypt_data error handling

This commit is contained in:
thePR0M3TH3AN
2025-08-04 13:10:49 -04:00
parent 054ffd3383
commit 4d7f28b400
5 changed files with 104 additions and 56 deletions

View File

@@ -148,7 +148,9 @@ class VaultService:
"""Restore a profile from ``data`` and sync."""
with self._lock:
decrypted = self._manager.vault.encryption_manager.decrypt_data(data)
decrypted = self._manager.vault.encryption_manager.decrypt_data(
data, context="profile"
)
index = json.loads(decrypted.decode("utf-8"))
self._manager.vault.save_index(index)
self._manager.sync_vault()

View File

@@ -92,11 +92,22 @@ class EncryptionManager:
logger.error(f"Failed to encrypt data: {e}", exc_info=True)
raise
def decrypt_data(self, encrypted_data: bytes) -> bytes:
"""
(3) The core migration logic. Tries the new format first, then falls back
to the old one. This is the ONLY place decryption logic should live.
def decrypt_data(
self, encrypted_data: bytes, context: Optional[str] = None
) -> bytes:
"""Decrypt ``encrypted_data`` handling legacy fallbacks.
Parameters
----------
encrypted_data:
The bytes to decrypt.
context:
Optional string describing what is being decrypted ("seed", "index", etc.)
for clearer error messages.
"""
ctx = f" {context}" if context else ""
try:
# Try the new V2 format first
if encrypted_data.startswith(b"V2:"):
@@ -118,7 +129,9 @@ class EncryptionManager:
)
return result
except InvalidToken:
raise InvalidToken("AES-GCM decryption failed.") from e
msg = f"Failed to decrypt{ctx}: invalid key or corrupt file"
logger.error(msg)
raise InvalidToken(msg) from e
# If it's not V2, it must be the legacy Fernet format
else:
@@ -129,19 +142,20 @@ class EncryptionManager:
logger.error(
"Legacy Fernet decryption failed. Vault may be corrupt or key is incorrect."
)
raise InvalidToken(
"Could not decrypt data with any available method."
) from e
raise e
except (InvalidToken, InvalidTag) as e:
if encrypted_data.startswith(b"V2:"):
# Already determined not to be legacy; re-raise
raise
if isinstance(e, InvalidToken) and str(e) == "AES-GCM payload too short":
raise
if not self._legacy_migrate_flag:
raise
logger.debug(f"Could not decrypt data: {e}")
logger.debug(f"Could not decrypt data{ctx}: {e}")
print(
colored(
"Failed to decrypt with current key. This may be a legacy index.",
f"Failed to decrypt{ctx} with current key. This may be a legacy index.",
"red",
)
)
@@ -172,7 +186,7 @@ class EncryptionManager:
)
legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
legacy_mgr._legacy_migrate_flag = False
result = legacy_mgr.decrypt_data(encrypted_data)
result = legacy_mgr.decrypt_data(encrypted_data, context=context)
try: # record iteration count for future runs
from .vault import Vault
from .config_manager import ConfigManager
@@ -194,7 +208,7 @@ class EncryptionManager:
last_exc = e2
logger.error(f"Failed legacy decryption attempt: {last_exc}", exc_info=True)
raise InvalidToken(
"Could not decrypt data with any available method."
f"Could not decrypt{ctx} with any available method."
) from e
# --- All functions below this point now use the smart `decrypt_data` method ---
@@ -241,7 +255,7 @@ class EncryptionManager:
encrypted_data = fh.read()
is_legacy = not encrypted_data.startswith(b"V2:")
decrypted_data = self.decrypt_data(encrypted_data)
decrypted_data = self.decrypt_data(encrypted_data, context="seed")
if is_legacy:
logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.")
@@ -266,7 +280,7 @@ class EncryptionManager:
with exclusive_lock(file_path) as fh:
fh.seek(0)
encrypted_data = fh.read()
return self.decrypt_data(encrypted_data)
return self.decrypt_data(encrypted_data, context=str(relative_path))
def save_json_data(self, data: dict, relative_path: Optional[Path] = None) -> None:
if relative_path is None:
@@ -297,7 +311,9 @@ class EncryptionManager:
self.last_migration_performed = False
try:
decrypted_data = self.decrypt_data(encrypted_data)
decrypted_data = self.decrypt_data(
encrypted_data, context=str(relative_path)
)
if USE_ORJSON:
data = json_lib.loads(decrypted_data)
else:
@@ -376,7 +392,9 @@ class EncryptionManager:
return data
try:
decrypted_data = self.decrypt_data(encrypted_data)
decrypted_data = self.decrypt_data(
encrypted_data, context=str(relative_path)
)
data = _process(decrypted_data)
self.save_json_data(data, relative_path) # This always saves in V2 format
self.update_checksum(relative_path)
@@ -391,7 +409,9 @@ class EncryptionManager:
)
legacy_key = _derive_legacy_key_from_password(password)
legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
decrypted_data = legacy_mgr.decrypt_data(encrypted_data)
decrypted_data = legacy_mgr.decrypt_data(
encrypted_data, context=str(relative_path)
)
data = _process(decrypted_data)
self.save_json_data(data, relative_path)
self.update_checksum(relative_path)

View File

@@ -112,7 +112,7 @@ def import_backup(
raw = Path(path).read_bytes()
if path.suffix.endswith(".enc"):
raw = vault.encryption_manager.decrypt_data(raw)
raw = vault.encryption_manager.decrypt_data(raw, context=str(path))
wrapper = json.loads(raw.decode("utf-8"))
if wrapper.get("format_version") != FORMAT_VERSION:
@@ -129,7 +129,7 @@ def import_backup(
)
key = _derive_export_key(seed)
enc_mgr = EncryptionManager(key, vault.fingerprint_dir)
index_bytes = enc_mgr.decrypt_data(payload)
index_bytes = enc_mgr.decrypt_data(payload, context="backup payload")
index = json.loads(index_bytes.decode("utf-8"))
checksum = json_checksum(index)

View File

@@ -1,36 +0,0 @@
import base64
import hashlib
import unicodedata
from helpers import TEST_PASSWORD
import seedpass.core.encryption as enc_module
from seedpass.core.encryption import EncryptionManager
from utils.key_derivation import derive_key_from_password
def test_decrypt_data_password_fallback(tmp_path, monkeypatch):
calls: list[int] = []
def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes:
calls.append(iterations)
normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32)
return base64.urlsafe_b64encode(key)
monkeypatch.setattr(
enc_module, "_derive_legacy_key_from_password", _fast_legacy_key
)
monkeypatch.setattr(
enc_module, "prompt_existing_password", lambda *_a, **_k: TEST_PASSWORD
)
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1")
legacy_key = _fast_legacy_key(TEST_PASSWORD, iterations=50_000)
legacy_mgr = EncryptionManager(legacy_key, tmp_path)
payload = legacy_mgr.encrypt_data(b"secret")
new_key = derive_key_from_password(TEST_PASSWORD, "fp")
new_mgr = EncryptionManager(new_key, tmp_path)
assert new_mgr.decrypt_data(payload) == b"secret"
assert calls == [50_000, 50_000]

View File

@@ -0,0 +1,62 @@
import base64
import hashlib
import unicodedata
import pytest
from cryptography.fernet import InvalidToken
from helpers import TEST_PASSWORD, TEST_SEED
from seedpass.core.encryption import EncryptionManager
from utils.key_derivation import derive_index_key
def test_wrong_password_message(tmp_path):
key = derive_index_key(TEST_SEED)
mgr = EncryptionManager(key, tmp_path)
payload = mgr.encrypt_data(b"secret")
wrong_key = bytearray(key)
wrong_key[0] ^= 1
wrong_mgr = EncryptionManager(bytes(wrong_key), tmp_path)
with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc:
wrong_mgr.decrypt_data(payload, context="index")
assert "index" in str(exc.value)
def test_legacy_file_requires_migration_message(tmp_path, monkeypatch, capsys):
def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes:
normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32)
return base64.urlsafe_b64encode(key)
monkeypatch.setattr(
"seedpass.core.encryption._derive_legacy_key_from_password", _fast_legacy_key
)
monkeypatch.setattr(
"seedpass.core.encryption.prompt_existing_password",
lambda *_a, **_k: TEST_PASSWORD,
)
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1")
legacy_key = _fast_legacy_key(TEST_PASSWORD)
legacy_mgr = EncryptionManager(legacy_key, tmp_path)
token = legacy_mgr.fernet.encrypt(b"secret")
new_mgr = EncryptionManager(derive_index_key(TEST_SEED), tmp_path)
assert new_mgr.decrypt_data(token, context="index") == b"secret"
out = capsys.readouterr().out
assert "Failed to decrypt index" in out
assert "legacy index" in out
def test_corrupted_data_message(tmp_path):
key = derive_index_key(TEST_SEED)
mgr = EncryptionManager(key, tmp_path)
payload = bytearray(mgr.encrypt_data(b"secret"))
payload[-1] ^= 0xFF
with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc:
mgr.decrypt_data(bytes(payload), context="index")
assert "index" in str(exc.value)