# /src/seedpass.core/encryption.py

import logging
import unicodedata

try:
    import orjson as json_lib  # type: ignore

    JSONDecodeError = json_lib.JSONDecodeError
    USE_ORJSON = True
except Exception:  # pragma: no cover - fallback for environments without orjson
    import json as json_lib
    from json import JSONDecodeError

    USE_ORJSON = False

import hashlib
import os
import base64
from dataclasses import asdict
from pathlib import Path
from typing import Optional, Tuple

from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag
from cryptography.fernet import Fernet, InvalidToken
from termcolor import colored
from utils.file_lock import exclusive_lock
from mnemonic import Mnemonic

from utils.password_prompt import prompt_existing_password
from utils.key_derivation import KdfConfig, CURRENT_KDF_VERSION

# Instantiate the logger
logger = logging.getLogger(__name__)


def _derive_legacy_key_from_password(password: str, iterations: int = 100_000) -> bytes:
    """Derive legacy Fernet key using password only (no fingerprint)."""
    normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
    key = hashlib.pbkdf2_hmac("sha256", normalized, b"", iterations, dklen=32)
    return base64.urlsafe_b64encode(key)


class LegacyFormatRequiresMigrationError(Exception):
    """Raised when legacy-encrypted data needs user-guided migration."""

    def __init__(self, context: Optional[str] = None) -> None:
        msg = (
            f"Legacy data detected for {context}" if context else "Legacy data detected"
        )
        super().__init__(msg)
        self.context = context


class EncryptionManager:
    """
    Manages encryption and decryption, handling migration from legacy Fernet
    to modern AES-GCM.
    """

    def __init__(self, encryption_key: bytes, fingerprint_dir: Path):
        """
        Initializes the EncryptionManager with keys for both new (AES-GCM)
        and legacy (Fernet) encryption formats.

        Parameters:
            encryption_key (bytes): A base64-encoded key.
            fingerprint_dir (Path): The directory corresponding to the fingerprint.
        """
        self.fingerprint_dir = fingerprint_dir
        self.parent_seed_file = self.fingerprint_dir / "parent_seed.enc"

        try:
            if isinstance(encryption_key, str):
                encryption_key = encryption_key.encode()

            # (1) Keep both the legacy Fernet instance and the new AES-GCM cipher ready.
            self.key_b64 = encryption_key
            self.fernet = Fernet(self.key_b64)

            self.key = base64.urlsafe_b64decode(self.key_b64)
            self.cipher = AESGCM(self.key)

            logger.debug(f"EncryptionManager initialized for {self.fingerprint_dir}")
        except Exception as e:
            logger.error(
                f"Failed to initialize ciphers with provided encryption key: {e}",
                exc_info=True,
            )
            raise

        # Track user preference for handling legacy indexes
        self._legacy_migrate_flag = True
        self.last_migration_performed = False

    def encrypt_data(self, data: bytes) -> bytes:
        """
        (2) Encrypts data using the NEW AES-GCM format, prepending a version
        header and the nonce. All new data will be in this format.
        """
        try:
            nonce = os.urandom(12)  # 96-bit nonce is recommended for AES-GCM
            ciphertext = self.cipher.encrypt(nonce, data, None)
            return b"V2:" + nonce + ciphertext
        except Exception as e:
            logger.error(f"Failed to encrypt data: {e}", exc_info=True)
            raise
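
    # On-disk layout produced by ``encrypt_data`` and expected by ``decrypt_data``:
    #
    #   b"V2:" + nonce (12 bytes) + AES-GCM ciphertext (plaintext length + 16-byte tag)
    #
    # Anything without the ``V2:`` header is treated as legacy Fernet data.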
""" ctx = f" {context}" if context else "" try: # Try the new V2 format first if encrypted_data.startswith(b"V2:"): try: nonce = encrypted_data[3:15] ciphertext = encrypted_data[15:] if len(ciphertext) < 16: logger.error("AES-GCM payload too short") raise InvalidToken("AES-GCM payload too short") return self.cipher.decrypt(nonce, ciphertext, None) except InvalidTag as e: logger.debug( "AES-GCM decryption failed: Invalid authentication tag." ) try: result = self.fernet.decrypt(encrypted_data[3:]) logger.warning( "Legacy-format file had incorrect 'V2:' header; decrypted with Fernet" ) return result except InvalidToken: msg = f"Failed to decrypt{ctx}: invalid key or corrupt file" logger.error(msg) raise InvalidToken(msg) from e # If it's not V2, it must be the legacy Fernet format else: logger.warning("Data is in legacy Fernet format. Attempting migration.") try: return self.fernet.decrypt(encrypted_data) except InvalidToken as e: logger.error( "Legacy Fernet decryption failed. Vault may be corrupt or key is incorrect." ) raise e except (InvalidToken, InvalidTag) as e: if encrypted_data.startswith(b"V2:"): # Already determined not to be legacy; re-raise raise if isinstance(e, InvalidToken) and str(e) == "AES-GCM payload too short": raise if not self._legacy_migrate_flag: raise logger.debug(f"Could not decrypt data{ctx}: {e}") raise LegacyFormatRequiresMigrationError(context) def decrypt_legacy( self, encrypted_data: bytes, password: str, context: Optional[str] = None ) -> bytes: """Decrypt ``encrypted_data`` using legacy password-only key derivation.""" ctx = f" {context}" if context else "" last_exc: Optional[Exception] = None for iter_count in [50_000, 100_000]: try: legacy_key = _derive_legacy_key_from_password( password, iterations=iter_count ) legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir) legacy_mgr._legacy_migrate_flag = False result = legacy_mgr.decrypt_data(encrypted_data, context=context) try: # record iteration count for future runs from .vault import Vault from .config_manager import ConfigManager cfg_mgr = ConfigManager( Vault(self, self.fingerprint_dir), self.fingerprint_dir ) cfg_mgr.set_kdf_iterations(iter_count) except Exception: # pragma: no cover - best effort logger.error( "Failed to record PBKDF2 iteration count in config", exc_info=True, ) logger.warning( "Data decrypted using legacy password-only key derivation." ) return result except Exception as e2: # pragma: no cover - try next iteration last_exc = e2 logger.error(f"Failed legacy decryption attempt: {last_exc}", exc_info=True) raise InvalidToken( f"Could not decrypt{ctx} with any available method." ) from last_exc # --- All functions below this point now use the smart `decrypt_data` method --- def resolve_relative_path(self, relative_path: Path) -> Path: """Resolve ``relative_path`` within ``fingerprint_dir`` and validate it. Parameters ---------- relative_path: The user-supplied path relative to ``fingerprint_dir``. Returns ------- Path The normalized absolute path inside ``fingerprint_dir``. Raises ------ ValueError If the resulting path is absolute or escapes ``fingerprint_dir``. 
""" candidate = (self.fingerprint_dir / relative_path).resolve() if not candidate.is_relative_to(self.fingerprint_dir.resolve()): raise ValueError("Invalid path outside fingerprint directory") return candidate def encrypt_parent_seed( self, parent_seed: str, kdf: Optional[KdfConfig] = None ) -> None: """Encrypts and saves the parent seed to 'parent_seed.enc'.""" data = parent_seed.encode("utf-8") self.encrypt_and_save_file(data, self.parent_seed_file, kdf=kdf) logger.info(f"Parent seed encrypted and saved to '{self.parent_seed_file}'.") def decrypt_parent_seed(self) -> str: """Decrypts and returns the parent seed, handling migration.""" with exclusive_lock(self.parent_seed_file) as fh: fh.seek(0) blob = fh.read() kdf, encrypted_data = self._deserialize(blob) is_legacy = not encrypted_data.startswith(b"V2:") decrypted_data = self.decrypt_data(encrypted_data, context="seed") if is_legacy: logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.") self.encrypt_parent_seed(decrypted_data.decode("utf-8").strip(), kdf=kdf) return decrypted_data.decode("utf-8").strip() def _serialize(self, kdf: KdfConfig, ciphertext: bytes) -> bytes: payload = {"kdf": asdict(kdf), "ct": base64.b64encode(ciphertext).decode()} if USE_ORJSON: return json_lib.dumps(payload) return json_lib.dumps(payload, separators=(",", ":")).encode("utf-8") def _deserialize(self, blob: bytes) -> Tuple[KdfConfig, bytes]: """Return ``(KdfConfig, ciphertext)`` from serialized *blob*. Legacy files stored the raw ciphertext without a JSON wrapper. If decoding the wrapper fails, treat ``blob`` as the ciphertext and return a default HKDF configuration. """ try: if USE_ORJSON: obj = json_lib.loads(blob) else: obj = json_lib.loads(blob.decode("utf-8")) kdf = KdfConfig(**obj.get("kdf", {})) ct_b64 = obj.get("ct", "") ciphertext = base64.b64decode(ct_b64) if ciphertext: return kdf, ciphertext except Exception: # pragma: no cover - fall back to legacy path pass # Legacy format: ``blob`` already contains the ciphertext return ( KdfConfig(name="hkdf", version=CURRENT_KDF_VERSION, params={}, salt_b64=""), blob, ) def encrypt_and_save_file( self, data: bytes, relative_path: Path, *, kdf: Optional[KdfConfig] = None ) -> None: if kdf is None: kdf = KdfConfig() file_path = self.resolve_relative_path(relative_path) file_path.parent.mkdir(parents=True, exist_ok=True) encrypted_data = self.encrypt_data(data) payload = self._serialize(kdf, encrypted_data) with exclusive_lock(file_path) as fh: fh.seek(0) fh.truncate() fh.write(payload) fh.flush() os.fsync(fh.fileno()) os.chmod(file_path, 0o600) def decrypt_file(self, relative_path: Path) -> bytes: file_path = self.resolve_relative_path(relative_path) with exclusive_lock(file_path) as fh: fh.seek(0) blob = fh.read() _, encrypted_data = self._deserialize(blob) return self.decrypt_data(encrypted_data, context=str(relative_path)) def get_file_kdf(self, relative_path: Path) -> KdfConfig: file_path = self.resolve_relative_path(relative_path) with exclusive_lock(file_path) as fh: fh.seek(0) blob = fh.read() kdf, _ = self._deserialize(blob) return kdf def save_json_data( self, data: dict, relative_path: Optional[Path] = None, *, kdf: Optional[KdfConfig] = None, ) -> None: if relative_path is None: relative_path = Path("seedpass_entries_db.json.enc") if USE_ORJSON: json_data = json_lib.dumps(data) else: json_data = json_lib.dumps(data, separators=(",", ":")).encode("utf-8") self.encrypt_and_save_file(json_data, relative_path, kdf=kdf) logger.debug(f"JSON data encrypted and saved to 
'{relative_path}'.") def load_json_data( self, relative_path: Optional[Path] = None, *, return_kdf: bool = False ) -> dict | Tuple[dict, KdfConfig]: """ Loads and decrypts JSON data, automatically migrating and re-saving if it's in the legacy format. """ if relative_path is None: relative_path = Path("seedpass_entries_db.json.enc") file_path = self.resolve_relative_path(relative_path) if not file_path.exists(): empty: dict = {"entries": {}} if return_kdf: return empty, KdfConfig( name="hkdf", version=CURRENT_KDF_VERSION, params={}, salt_b64="" ) return empty with exclusive_lock(file_path) as fh: fh.seek(0) blob = fh.read() kdf, encrypted_data = self._deserialize(blob) is_legacy = not encrypted_data.startswith(b"V2:") self.last_migration_performed = False try: decrypted_data = self.decrypt_data( encrypted_data, context=str(relative_path) ) if USE_ORJSON: data = json_lib.loads(decrypted_data) else: data = json_lib.loads(decrypted_data.decode("utf-8")) # If it was a legacy file, re-save it in the new format now if is_legacy and self._legacy_migrate_flag: logger.info(f"Migrating and re-saving legacy vault file: {file_path}") self.save_json_data(data, relative_path, kdf=kdf) self.update_checksum(relative_path) self.last_migration_performed = True if return_kdf: return data, kdf return data except (InvalidToken, InvalidTag, JSONDecodeError) as e: logger.error( f"FATAL: Could not decrypt or parse data from {file_path}: {e}", exc_info=True, ) raise def get_encrypted_index(self) -> Optional[bytes]: relative_path = Path("seedpass_entries_db.json.enc") file_path = self.resolve_relative_path(relative_path) if not file_path.exists(): return None with exclusive_lock(file_path) as fh: fh.seek(0) return fh.read() def decrypt_and_save_index_from_nostr( self, encrypted_data: bytes, relative_path: Optional[Path] = None, *, strict: bool = True, merge: bool = False, ) -> bool: """Decrypts data from Nostr and saves it. Parameters ---------- encrypted_data: The payload downloaded from Nostr. relative_path: Destination filename under the profile directory. strict: When ``True`` (default) re-raise any decryption error. When ``False`` return ``False`` if decryption fails. 
""" if relative_path is None: relative_path = Path("seedpass_entries_db.json.enc") kdf, ciphertext = self._deserialize(encrypted_data) is_legacy = not ciphertext.startswith(b"V2:") self.last_migration_performed = False def _process(decrypted: bytes) -> dict: if USE_ORJSON: data = json_lib.loads(decrypted) else: data = json_lib.loads(decrypted.decode("utf-8")) existing_file = self.resolve_relative_path(relative_path) if merge and existing_file.exists(): current = self.load_json_data(relative_path) current_entries = current.get("entries", {}) for idx, entry in data.get("entries", {}).items(): cur_ts = current_entries.get(idx, {}).get("modified_ts", 0) new_ts = entry.get("modified_ts", 0) if idx not in current_entries or new_ts >= cur_ts: current_entries[idx] = entry current["entries"] = current_entries if "schema_version" in data: current["schema_version"] = max( current.get("schema_version", 0), data.get("schema_version", 0) ) data = current return data try: decrypted_data = self.decrypt_data(ciphertext, context=str(relative_path)) data = _process(decrypted_data) self.save_json_data(data, relative_path, kdf=kdf) self.update_checksum(relative_path) logger.info("Index file from Nostr was processed and saved successfully.") self.last_migration_performed = is_legacy return True except (InvalidToken, LegacyFormatRequiresMigrationError): try: password = prompt_existing_password( "Enter your master password for legacy decryption: " ) decrypted_data = self.decrypt_legacy( ciphertext, password, context=str(relative_path) ) data = _process(decrypted_data) self.save_json_data(data, relative_path, kdf=kdf) self.update_checksum(relative_path) logger.warning( "Index decrypted using legacy password-only key derivation." ) print( colored( "Warning: index decrypted with legacy key; it will be re-encrypted.", "yellow", ) ) self.last_migration_performed = True return True except Exception as e2: if strict: logger.error( f"Failed legacy decryption attempt: {e2}", exc_info=True, ) print( colored( f"Error: Failed to decrypt and save data from Nostr: {e2}", "red", ) ) raise logger.warning(f"Failed to decrypt index from Nostr: {e2}") return False except Exception as e: # pragma: no cover - error handling if strict: logger.error( f"Failed to decrypt and save data from Nostr: {e}", exc_info=True, ) print( colored( f"Error: Failed to decrypt and save data from Nostr: {e}", "red", ) ) raise logger.warning(f"Failed to decrypt index from Nostr: {e}") return False def update_checksum(self, relative_path: Optional[Path] = None) -> None: """Updates the checksum file for the specified file.""" if relative_path is None: relative_path = Path("seedpass_entries_db.json.enc") file_path = self.resolve_relative_path(relative_path) if not file_path.exists(): return try: with exclusive_lock(file_path) as fh: fh.seek(0) encrypted_bytes = fh.read() checksum = hashlib.sha256(encrypted_bytes).hexdigest() # Build checksum path by stripping both `.json` and `.enc` checksum_base = file_path.with_suffix("").with_suffix("") checksum_file = checksum_base.parent / f"{checksum_base.name}_checksum.txt" # Remove legacy checksum file if present legacy_checksum = file_path.parent / f"{file_path.stem}_checksum.txt" if legacy_checksum != checksum_file and legacy_checksum.exists(): try: legacy_checksum.unlink() except Exception: logger.warning( f"Could not remove legacy checksum file '{legacy_checksum}'", exc_info=True, ) with exclusive_lock(checksum_file) as fh: fh.seek(0) fh.truncate() fh.write(checksum.encode("utf-8")) fh.flush() 

    def update_checksum(self, relative_path: Optional[Path] = None) -> None:
        """Updates the checksum file for the specified file."""
        if relative_path is None:
            relative_path = Path("seedpass_entries_db.json.enc")
        file_path = self.resolve_relative_path(relative_path)
        if not file_path.exists():
            return
        try:
            with exclusive_lock(file_path) as fh:
                fh.seek(0)
                encrypted_bytes = fh.read()
            checksum = hashlib.sha256(encrypted_bytes).hexdigest()

            # Build checksum path by stripping both `.json` and `.enc`
            checksum_base = file_path.with_suffix("").with_suffix("")
            checksum_file = checksum_base.parent / f"{checksum_base.name}_checksum.txt"

            # Remove legacy checksum file if present
            legacy_checksum = file_path.parent / f"{file_path.stem}_checksum.txt"
            if legacy_checksum != checksum_file and legacy_checksum.exists():
                try:
                    legacy_checksum.unlink()
                except Exception:
                    logger.warning(
                        f"Could not remove legacy checksum file '{legacy_checksum}'",
                        exc_info=True,
                    )

            with exclusive_lock(checksum_file) as fh:
                fh.seek(0)
                fh.truncate()
                fh.write(checksum.encode("utf-8"))
                fh.flush()
                os.fsync(fh.fileno())
            os.chmod(checksum_file, 0o600)
        except Exception as e:
            logger.error(
                f"Failed to update checksum for '{relative_path}': {e}",
                exc_info=True,
            )
            raise

    def validate_seed(self, seed_phrase: str) -> tuple[bool, Optional[str]]:
        """Validate a BIP-39 mnemonic.

        Returns a tuple of ``(is_valid, error_message)`` where ``error_message``
        is ``None`` when the mnemonic is valid.
        """
        try:
            if Mnemonic("english").check(seed_phrase):
                logger.debug("Seed phrase validated successfully.")
                return True, None
            logger.error("Seed phrase failed BIP-39 validation.")
            return False, "Invalid seed phrase."
        except Exception as e:
            logger.error(f"Error validating seed phrase: {e}", exc_info=True)
            return False, f"Failed to validate seed phrase: {e}"

    def derive_seed_from_mnemonic(self, mnemonic: str, passphrase: str = "") -> bytes:
        """Derive a BIP-39 seed from ``mnemonic`` and an optional ``passphrase``."""
        try:
            if not isinstance(mnemonic, str):
                if isinstance(mnemonic, list):
                    mnemonic = " ".join(mnemonic)
                else:
                    mnemonic = str(mnemonic)
            if not isinstance(mnemonic, str):
                raise TypeError("Mnemonic must be a string after conversion")

            from bip_utils import Bip39SeedGenerator

            seed = Bip39SeedGenerator(mnemonic).Generate(passphrase)
            logger.debug("Seed derived successfully from mnemonic.")
            return seed
        except Exception as e:
            logger.error(f"Failed to derive seed from mnemonic: {e}", exc_info=True)
            print(colored(f"Error: Failed to derive seed from mnemonic: {e}", "red"))
            raise
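
# Illustrative round trip (sketch only: the throwaway Fernet key and the path below
# are placeholders, not the production flow, which obtains both from the profile's
# key-derivation and fingerprint setup):
#
#   from cryptography.fernet import Fernet
#   mgr = EncryptionManager(Fernet.generate_key(), Path("/path/to/fingerprint_dir"))
#   mgr.save_json_data({"entries": {}})   # writes seedpass_entries_db.json.enc
#   data = mgr.load_json_data()           # transparently migrates legacy files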