mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-08 07:18:47 +00:00
Refactor decrypt_data error handling
This commit is contained in:
@@ -148,7 +148,9 @@ class VaultService:
|
||||
"""Restore a profile from ``data`` and sync."""
|
||||
|
||||
with self._lock:
|
||||
decrypted = self._manager.vault.encryption_manager.decrypt_data(data)
|
||||
decrypted = self._manager.vault.encryption_manager.decrypt_data(
|
||||
data, context="profile"
|
||||
)
|
||||
index = json.loads(decrypted.decode("utf-8"))
|
||||
self._manager.vault.save_index(index)
|
||||
self._manager.sync_vault()
|
||||
|
@@ -92,11 +92,22 @@ class EncryptionManager:
|
||||
logger.error(f"Failed to encrypt data: {e}", exc_info=True)
|
||||
raise
|
||||
|
||||
def decrypt_data(self, encrypted_data: bytes) -> bytes:
|
||||
"""
|
||||
(3) The core migration logic. Tries the new format first, then falls back
|
||||
to the old one. This is the ONLY place decryption logic should live.
|
||||
def decrypt_data(
|
||||
self, encrypted_data: bytes, context: Optional[str] = None
|
||||
) -> bytes:
|
||||
"""Decrypt ``encrypted_data`` handling legacy fallbacks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
encrypted_data:
|
||||
The bytes to decrypt.
|
||||
context:
|
||||
Optional string describing what is being decrypted ("seed", "index", etc.)
|
||||
for clearer error messages.
|
||||
"""
|
||||
|
||||
ctx = f" {context}" if context else ""
|
||||
|
||||
try:
|
||||
# Try the new V2 format first
|
||||
if encrypted_data.startswith(b"V2:"):
|
||||
@@ -118,7 +129,9 @@ class EncryptionManager:
|
||||
)
|
||||
return result
|
||||
except InvalidToken:
|
||||
raise InvalidToken("AES-GCM decryption failed.") from e
|
||||
msg = f"Failed to decrypt{ctx}: invalid key or corrupt file"
|
||||
logger.error(msg)
|
||||
raise InvalidToken(msg) from e
|
||||
|
||||
# If it's not V2, it must be the legacy Fernet format
|
||||
else:
|
||||
@@ -129,19 +142,20 @@ class EncryptionManager:
|
||||
logger.error(
|
||||
"Legacy Fernet decryption failed. Vault may be corrupt or key is incorrect."
|
||||
)
|
||||
raise InvalidToken(
|
||||
"Could not decrypt data with any available method."
|
||||
) from e
|
||||
raise e
|
||||
|
||||
except (InvalidToken, InvalidTag) as e:
|
||||
if encrypted_data.startswith(b"V2:"):
|
||||
# Already determined not to be legacy; re-raise
|
||||
raise
|
||||
if isinstance(e, InvalidToken) and str(e) == "AES-GCM payload too short":
|
||||
raise
|
||||
if not self._legacy_migrate_flag:
|
||||
raise
|
||||
logger.debug(f"Could not decrypt data: {e}")
|
||||
logger.debug(f"Could not decrypt data{ctx}: {e}")
|
||||
print(
|
||||
colored(
|
||||
"Failed to decrypt with current key. This may be a legacy index.",
|
||||
f"Failed to decrypt{ctx} with current key. This may be a legacy index.",
|
||||
"red",
|
||||
)
|
||||
)
|
||||
@@ -172,7 +186,7 @@ class EncryptionManager:
|
||||
)
|
||||
legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
|
||||
legacy_mgr._legacy_migrate_flag = False
|
||||
result = legacy_mgr.decrypt_data(encrypted_data)
|
||||
result = legacy_mgr.decrypt_data(encrypted_data, context=context)
|
||||
try: # record iteration count for future runs
|
||||
from .vault import Vault
|
||||
from .config_manager import ConfigManager
|
||||
@@ -194,7 +208,7 @@ class EncryptionManager:
|
||||
last_exc = e2
|
||||
logger.error(f"Failed legacy decryption attempt: {last_exc}", exc_info=True)
|
||||
raise InvalidToken(
|
||||
"Could not decrypt data with any available method."
|
||||
f"Could not decrypt{ctx} with any available method."
|
||||
) from e
|
||||
|
||||
# --- All functions below this point now use the smart `decrypt_data` method ---
|
||||
@@ -241,7 +255,7 @@ class EncryptionManager:
|
||||
encrypted_data = fh.read()
|
||||
|
||||
is_legacy = not encrypted_data.startswith(b"V2:")
|
||||
decrypted_data = self.decrypt_data(encrypted_data)
|
||||
decrypted_data = self.decrypt_data(encrypted_data, context="seed")
|
||||
|
||||
if is_legacy:
|
||||
logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.")
|
||||
@@ -266,7 +280,7 @@ class EncryptionManager:
|
||||
with exclusive_lock(file_path) as fh:
|
||||
fh.seek(0)
|
||||
encrypted_data = fh.read()
|
||||
return self.decrypt_data(encrypted_data)
|
||||
return self.decrypt_data(encrypted_data, context=str(relative_path))
|
||||
|
||||
def save_json_data(self, data: dict, relative_path: Optional[Path] = None) -> None:
|
||||
if relative_path is None:
|
||||
@@ -297,7 +311,9 @@ class EncryptionManager:
|
||||
self.last_migration_performed = False
|
||||
|
||||
try:
|
||||
decrypted_data = self.decrypt_data(encrypted_data)
|
||||
decrypted_data = self.decrypt_data(
|
||||
encrypted_data, context=str(relative_path)
|
||||
)
|
||||
if USE_ORJSON:
|
||||
data = json_lib.loads(decrypted_data)
|
||||
else:
|
||||
@@ -376,7 +392,9 @@ class EncryptionManager:
|
||||
return data
|
||||
|
||||
try:
|
||||
decrypted_data = self.decrypt_data(encrypted_data)
|
||||
decrypted_data = self.decrypt_data(
|
||||
encrypted_data, context=str(relative_path)
|
||||
)
|
||||
data = _process(decrypted_data)
|
||||
self.save_json_data(data, relative_path) # This always saves in V2 format
|
||||
self.update_checksum(relative_path)
|
||||
@@ -391,7 +409,9 @@ class EncryptionManager:
|
||||
)
|
||||
legacy_key = _derive_legacy_key_from_password(password)
|
||||
legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
|
||||
decrypted_data = legacy_mgr.decrypt_data(encrypted_data)
|
||||
decrypted_data = legacy_mgr.decrypt_data(
|
||||
encrypted_data, context=str(relative_path)
|
||||
)
|
||||
data = _process(decrypted_data)
|
||||
self.save_json_data(data, relative_path)
|
||||
self.update_checksum(relative_path)
|
||||
|
@@ -112,7 +112,7 @@ def import_backup(
|
||||
|
||||
raw = Path(path).read_bytes()
|
||||
if path.suffix.endswith(".enc"):
|
||||
raw = vault.encryption_manager.decrypt_data(raw)
|
||||
raw = vault.encryption_manager.decrypt_data(raw, context=str(path))
|
||||
|
||||
wrapper = json.loads(raw.decode("utf-8"))
|
||||
if wrapper.get("format_version") != FORMAT_VERSION:
|
||||
@@ -129,7 +129,7 @@ def import_backup(
|
||||
)
|
||||
key = _derive_export_key(seed)
|
||||
enc_mgr = EncryptionManager(key, vault.fingerprint_dir)
|
||||
index_bytes = enc_mgr.decrypt_data(payload)
|
||||
index_bytes = enc_mgr.decrypt_data(payload, context="backup payload")
|
||||
index = json.loads(index_bytes.decode("utf-8"))
|
||||
|
||||
checksum = json_checksum(index)
|
||||
|
@@ -1,36 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import unicodedata
|
||||
|
||||
from helpers import TEST_PASSWORD
|
||||
import seedpass.core.encryption as enc_module
|
||||
from seedpass.core.encryption import EncryptionManager
|
||||
from utils.key_derivation import derive_key_from_password
|
||||
|
||||
|
||||
def test_decrypt_data_password_fallback(tmp_path, monkeypatch):
|
||||
calls: list[int] = []
|
||||
|
||||
def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes:
|
||||
calls.append(iterations)
|
||||
normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
|
||||
key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32)
|
||||
return base64.urlsafe_b64encode(key)
|
||||
|
||||
monkeypatch.setattr(
|
||||
enc_module, "_derive_legacy_key_from_password", _fast_legacy_key
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
enc_module, "prompt_existing_password", lambda *_a, **_k: TEST_PASSWORD
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1")
|
||||
|
||||
legacy_key = _fast_legacy_key(TEST_PASSWORD, iterations=50_000)
|
||||
legacy_mgr = EncryptionManager(legacy_key, tmp_path)
|
||||
payload = legacy_mgr.encrypt_data(b"secret")
|
||||
|
||||
new_key = derive_key_from_password(TEST_PASSWORD, "fp")
|
||||
new_mgr = EncryptionManager(new_key, tmp_path)
|
||||
|
||||
assert new_mgr.decrypt_data(payload) == b"secret"
|
||||
assert calls == [50_000, 50_000]
|
62
src/tests/test_decrypt_messages.py
Normal file
62
src/tests/test_decrypt_messages.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import unicodedata
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import InvalidToken
|
||||
|
||||
from helpers import TEST_PASSWORD, TEST_SEED
|
||||
from seedpass.core.encryption import EncryptionManager
|
||||
from utils.key_derivation import derive_index_key
|
||||
|
||||
|
||||
def test_wrong_password_message(tmp_path):
|
||||
key = derive_index_key(TEST_SEED)
|
||||
mgr = EncryptionManager(key, tmp_path)
|
||||
payload = mgr.encrypt_data(b"secret")
|
||||
|
||||
wrong_key = bytearray(key)
|
||||
wrong_key[0] ^= 1
|
||||
wrong_mgr = EncryptionManager(bytes(wrong_key), tmp_path)
|
||||
|
||||
with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc:
|
||||
wrong_mgr.decrypt_data(payload, context="index")
|
||||
assert "index" in str(exc.value)
|
||||
|
||||
|
||||
def test_legacy_file_requires_migration_message(tmp_path, monkeypatch, capsys):
|
||||
def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes:
|
||||
normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
|
||||
key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32)
|
||||
return base64.urlsafe_b64encode(key)
|
||||
|
||||
monkeypatch.setattr(
|
||||
"seedpass.core.encryption._derive_legacy_key_from_password", _fast_legacy_key
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"seedpass.core.encryption.prompt_existing_password",
|
||||
lambda *_a, **_k: TEST_PASSWORD,
|
||||
)
|
||||
monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1")
|
||||
|
||||
legacy_key = _fast_legacy_key(TEST_PASSWORD)
|
||||
legacy_mgr = EncryptionManager(legacy_key, tmp_path)
|
||||
token = legacy_mgr.fernet.encrypt(b"secret")
|
||||
|
||||
new_mgr = EncryptionManager(derive_index_key(TEST_SEED), tmp_path)
|
||||
assert new_mgr.decrypt_data(token, context="index") == b"secret"
|
||||
|
||||
out = capsys.readouterr().out
|
||||
assert "Failed to decrypt index" in out
|
||||
assert "legacy index" in out
|
||||
|
||||
|
||||
def test_corrupted_data_message(tmp_path):
|
||||
key = derive_index_key(TEST_SEED)
|
||||
mgr = EncryptionManager(key, tmp_path)
|
||||
payload = bytearray(mgr.encrypt_data(b"secret"))
|
||||
payload[-1] ^= 0xFF
|
||||
|
||||
with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc:
|
||||
mgr.decrypt_data(bytes(payload), context="index")
|
||||
assert "index" in str(exc.value)
|
Reference in New Issue
Block a user