mirror of https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-09 07:48:57 +00:00
Merge pull request #758 from PR0M3TH3AN/codex/refactor-encryptionmanager.decrypt_data
Refactor decrypt_data error handling
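The diff below threads an optional `context` argument through `EncryptionManager.decrypt_data` so decryption failures name what could not be decrypted (for example "seed", "index", or a file path) instead of raising a generic error. As a rough sketch of the new call pattern (the `enc_mgr` and `blob` names are placeholders for illustration, not part of the diff):

    # Hypothetical caller; enc_mgr is an already-configured EncryptionManager.
    plaintext = enc_mgr.decrypt_data(blob, context="index")
    # On a bad key or corrupt payload this now raises InvalidToken with a
    # message like "Failed to decrypt index: invalid key or corrupt file".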
@@ -148,7 +148,9 @@ class VaultService:
         """Restore a profile from ``data`` and sync."""

         with self._lock:
-            decrypted = self._manager.vault.encryption_manager.decrypt_data(data)
+            decrypted = self._manager.vault.encryption_manager.decrypt_data(
+                data, context="profile"
+            )
             index = json.loads(decrypted.decode("utf-8"))
             self._manager.vault.save_index(index)
             self._manager.sync_vault()
@@ -92,11 +92,22 @@ class EncryptionManager:
             logger.error(f"Failed to encrypt data: {e}", exc_info=True)
             raise

-    def decrypt_data(self, encrypted_data: bytes) -> bytes:
-        """
-        (3) The core migration logic. Tries the new format first, then falls back
-        to the old one. This is the ONLY place decryption logic should live.
+    def decrypt_data(
+        self, encrypted_data: bytes, context: Optional[str] = None
+    ) -> bytes:
+        """Decrypt ``encrypted_data`` handling legacy fallbacks.
+
+        Parameters
+        ----------
+        encrypted_data:
+            The bytes to decrypt.
+        context:
+            Optional string describing what is being decrypted ("seed", "index", etc.)
+            for clearer error messages.
         """
+
+        ctx = f" {context}" if context else ""
+
         try:
             # Try the new V2 format first
             if encrypted_data.startswith(b"V2:"):
@@ -118,7 +129,9 @@ class EncryptionManager:
                     )
                     return result
                 except InvalidToken:
-                    raise InvalidToken("AES-GCM decryption failed.") from e
+                    msg = f"Failed to decrypt{ctx}: invalid key or corrupt file"
+                    logger.error(msg)
+                    raise InvalidToken(msg) from e

             # If it's not V2, it must be the legacy Fernet format
             else:
@@ -129,19 +142,20 @@ class EncryptionManager:
                     logger.error(
                         "Legacy Fernet decryption failed. Vault may be corrupt or key is incorrect."
                     )
-                    raise InvalidToken(
-                        "Could not decrypt data with any available method."
-                    ) from e
+                    raise e

         except (InvalidToken, InvalidTag) as e:
+            if encrypted_data.startswith(b"V2:"):
+                # Already determined not to be legacy; re-raise
+                raise
             if isinstance(e, InvalidToken) and str(e) == "AES-GCM payload too short":
                 raise
             if not self._legacy_migrate_flag:
                 raise
-            logger.debug(f"Could not decrypt data: {e}")
+            logger.debug(f"Could not decrypt data{ctx}: {e}")
             print(
                 colored(
-                    "Failed to decrypt with current key. This may be a legacy index.",
+                    f"Failed to decrypt{ctx} with current key. This may be a legacy index.",
                     "red",
                 )
             )
@@ -172,7 +186,7 @@ class EncryptionManager:
             )
             legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
             legacy_mgr._legacy_migrate_flag = False
-            result = legacy_mgr.decrypt_data(encrypted_data)
+            result = legacy_mgr.decrypt_data(encrypted_data, context=context)
             try:  # record iteration count for future runs
                 from .vault import Vault
                 from .config_manager import ConfigManager
@@ -194,7 +208,7 @@ class EncryptionManager:
                 last_exc = e2
         logger.error(f"Failed legacy decryption attempt: {last_exc}", exc_info=True)
         raise InvalidToken(
-            "Could not decrypt data with any available method."
+            f"Could not decrypt{ctx} with any available method."
        ) from e

    # --- All functions below this point now use the smart `decrypt_data` method ---
@@ -241,7 +255,7 @@ class EncryptionManager:
             encrypted_data = fh.read()

         is_legacy = not encrypted_data.startswith(b"V2:")
-        decrypted_data = self.decrypt_data(encrypted_data)
+        decrypted_data = self.decrypt_data(encrypted_data, context="seed")

         if is_legacy:
             logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.")
@@ -266,7 +280,7 @@ class EncryptionManager:
         with exclusive_lock(file_path) as fh:
             fh.seek(0)
             encrypted_data = fh.read()
-        return self.decrypt_data(encrypted_data)
+        return self.decrypt_data(encrypted_data, context=str(relative_path))

     def save_json_data(self, data: dict, relative_path: Optional[Path] = None) -> None:
         if relative_path is None:
@@ -297,7 +311,9 @@ class EncryptionManager:
         self.last_migration_performed = False

         try:
-            decrypted_data = self.decrypt_data(encrypted_data)
+            decrypted_data = self.decrypt_data(
+                encrypted_data, context=str(relative_path)
+            )
             if USE_ORJSON:
                 data = json_lib.loads(decrypted_data)
             else:
@@ -376,7 +392,9 @@ class EncryptionManager:
             return data

         try:
-            decrypted_data = self.decrypt_data(encrypted_data)
+            decrypted_data = self.decrypt_data(
+                encrypted_data, context=str(relative_path)
+            )
             data = _process(decrypted_data)
             self.save_json_data(data, relative_path)  # This always saves in V2 format
             self.update_checksum(relative_path)
@@ -391,7 +409,9 @@ class EncryptionManager:
             )
             legacy_key = _derive_legacy_key_from_password(password)
             legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
-            decrypted_data = legacy_mgr.decrypt_data(encrypted_data)
+            decrypted_data = legacy_mgr.decrypt_data(
+                encrypted_data, context=str(relative_path)
+            )
             data = _process(decrypted_data)
             self.save_json_data(data, relative_path)
             self.update_checksum(relative_path)
@@ -112,7 +112,7 @@ def import_backup(

     raw = Path(path).read_bytes()
     if path.suffix.endswith(".enc"):
-        raw = vault.encryption_manager.decrypt_data(raw)
+        raw = vault.encryption_manager.decrypt_data(raw, context=str(path))

     wrapper = json.loads(raw.decode("utf-8"))
     if wrapper.get("format_version") != FORMAT_VERSION:
@@ -129,7 +129,7 @@ def import_backup(
     )
     key = _derive_export_key(seed)
     enc_mgr = EncryptionManager(key, vault.fingerprint_dir)
-    index_bytes = enc_mgr.decrypt_data(payload)
+    index_bytes = enc_mgr.decrypt_data(payload, context="backup payload")
     index = json.loads(index_bytes.decode("utf-8"))

     checksum = json_checksum(index)
@@ -1,36 +0,0 @@
-import base64
-import hashlib
-import unicodedata
-
-from helpers import TEST_PASSWORD
-import seedpass.core.encryption as enc_module
-from seedpass.core.encryption import EncryptionManager
-from utils.key_derivation import derive_key_from_password
-
-
-def test_decrypt_data_password_fallback(tmp_path, monkeypatch):
-    calls: list[int] = []
-
-    def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes:
-        calls.append(iterations)
-        normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
-        key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32)
-        return base64.urlsafe_b64encode(key)
-
-    monkeypatch.setattr(
-        enc_module, "_derive_legacy_key_from_password", _fast_legacy_key
-    )
-    monkeypatch.setattr(
-        enc_module, "prompt_existing_password", lambda *_a, **_k: TEST_PASSWORD
-    )
-    monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1")
-
-    legacy_key = _fast_legacy_key(TEST_PASSWORD, iterations=50_000)
-    legacy_mgr = EncryptionManager(legacy_key, tmp_path)
-    payload = legacy_mgr.encrypt_data(b"secret")
-
-    new_key = derive_key_from_password(TEST_PASSWORD, "fp")
-    new_mgr = EncryptionManager(new_key, tmp_path)
-
-    assert new_mgr.decrypt_data(payload) == b"secret"
-    assert calls == [50_000, 50_000]
src/tests/test_decrypt_messages.py (new file, 62 lines added)
@@ -0,0 +1,62 @@
+import base64
+import hashlib
+import unicodedata
+
+import pytest
+from cryptography.fernet import InvalidToken
+
+from helpers import TEST_PASSWORD, TEST_SEED
+from seedpass.core.encryption import EncryptionManager
+from utils.key_derivation import derive_index_key
+
+
+def test_wrong_password_message(tmp_path):
+    key = derive_index_key(TEST_SEED)
+    mgr = EncryptionManager(key, tmp_path)
+    payload = mgr.encrypt_data(b"secret")
+
+    wrong_key = bytearray(key)
+    wrong_key[0] ^= 1
+    wrong_mgr = EncryptionManager(bytes(wrong_key), tmp_path)
+
+    with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc:
+        wrong_mgr.decrypt_data(payload, context="index")
+    assert "index" in str(exc.value)
+
+
+def test_legacy_file_requires_migration_message(tmp_path, monkeypatch, capsys):
+    def _fast_legacy_key(password: str, iterations: int = 100_000) -> bytes:
+        normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
+        key = hashlib.pbkdf2_hmac("sha256", normalized, b"", 1, dklen=32)
+        return base64.urlsafe_b64encode(key)
+
+    monkeypatch.setattr(
+        "seedpass.core.encryption._derive_legacy_key_from_password", _fast_legacy_key
+    )
+    monkeypatch.setattr(
+        "seedpass.core.encryption.prompt_existing_password",
+        lambda *_a, **_k: TEST_PASSWORD,
+    )
+    monkeypatch.setattr("builtins.input", lambda *_a, **_k: "1")
+
+    legacy_key = _fast_legacy_key(TEST_PASSWORD)
+    legacy_mgr = EncryptionManager(legacy_key, tmp_path)
+    token = legacy_mgr.fernet.encrypt(b"secret")
+
+    new_mgr = EncryptionManager(derive_index_key(TEST_SEED), tmp_path)
+    assert new_mgr.decrypt_data(token, context="index") == b"secret"
+
+    out = capsys.readouterr().out
+    assert "Failed to decrypt index" in out
+    assert "legacy index" in out
+
+
+def test_corrupted_data_message(tmp_path):
+    key = derive_index_key(TEST_SEED)
+    mgr = EncryptionManager(key, tmp_path)
+    payload = bytearray(mgr.encrypt_data(b"secret"))
+    payload[-1] ^= 0xFF
+
+    with pytest.raises(InvalidToken, match="invalid key or corrupt file") as exc:
+        mgr.decrypt_data(bytes(payload), context="index")
+    assert "index" in str(exc.value)
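For reference, a sketch of how a caller might surface the new context-aware errors. This is illustrative only: the `load_index` helper is not part of this change, but the exception type and message formats follow the hunks above.

    from cryptography.fernet import InvalidToken

    def load_index(enc_mgr, blob: bytes) -> bytes:
        # enc_mgr is assumed to be an EncryptionManager holding the profile key.
        try:
            return enc_mgr.decrypt_data(blob, context="index")
        except InvalidToken as exc:
            # With this change the message names what failed, e.g.
            # "Failed to decrypt index: invalid key or corrupt file" or
            # "Could not decrypt index with any available method."
            raise SystemExit(f"Vault error: {exc}") from exc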