# /src/seedpass/core/encryption.py
import logging
import unicodedata

try:
    import orjson as json_lib  # type: ignore

    JSONDecodeError = json_lib.JSONDecodeError
    USE_ORJSON = True
except Exception:  # pragma: no cover - fallback for environments without orjson
    import json as json_lib
    from json import JSONDecodeError

    USE_ORJSON = False

import hashlib
import os
import base64
from dataclasses import asdict
from pathlib import Path
from typing import Optional, Tuple

from cryptography.hazmat.primitives.ciphers.aead import AESGCM
from cryptography.exceptions import InvalidTag
from cryptography.fernet import Fernet, InvalidToken
from termcolor import colored

from utils.file_lock import exclusive_lock
from mnemonic import Mnemonic
from utils.password_prompt import prompt_existing_password
from utils.key_derivation import KdfConfig, CURRENT_KDF_VERSION

# Instantiate the logger
logger = logging.getLogger(__name__)

def _derive_legacy_key_from_password(password: str, iterations: int = 100_000) -> bytes:
    """Derive legacy Fernet key using password only (no fingerprint)."""
    normalized = unicodedata.normalize("NFKD", password).strip().encode("utf-8")
    key = hashlib.pbkdf2_hmac("sha256", normalized, b"", iterations, dklen=32)
    return base64.urlsafe_b64encode(key)


class LegacyFormatRequiresMigrationError(Exception):
    """Raised when legacy-encrypted data needs user-guided migration."""

    def __init__(self, context: Optional[str] = None) -> None:
        msg = (
            f"Legacy data detected for {context}" if context else "Legacy data detected"
        )
        super().__init__(msg)
        self.context = context

class EncryptionManager:
    """
    Manages encryption and decryption, handling migration from legacy Fernet
    to modern AES-GCM.
    """

    def __init__(self, encryption_key: bytes, fingerprint_dir: Path):
        """
        Initializes the EncryptionManager with keys for both new (AES-GCM)
        and legacy (Fernet) encryption formats.

        Parameters:
            encryption_key (bytes): A base64-encoded key.
            fingerprint_dir (Path): The directory corresponding to the fingerprint.
        """
        self.fingerprint_dir = fingerprint_dir
        self.parent_seed_file = self.fingerprint_dir / "parent_seed.enc"

        try:
            if isinstance(encryption_key, str):
                encryption_key = encryption_key.encode()

            # (1) Keep both the legacy Fernet instance and the new AES-GCM cipher ready.
            self.key_b64 = encryption_key
            self.fernet = Fernet(self.key_b64)
            self.key = base64.urlsafe_b64decode(self.key_b64)
            self.cipher = AESGCM(self.key)

            logger.debug(f"EncryptionManager initialized for {self.fingerprint_dir}")
        except Exception as e:
            logger.error(
                f"Failed to initialize ciphers with provided encryption key: {e}",
                exc_info=True,
            )
            raise

        # Track user preference for handling legacy indexes
        self._legacy_migrate_flag = True
        self.last_migration_performed = False
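
    # Note: the same base64 key material backs both ciphers above. It is used
    # verbatim as the Fernet key and, once base64-decoded, as the 256-bit
    # AES-GCM key, so a single derived key can read legacy and V2 data alike.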

    def encrypt_data(self, data: bytes) -> bytes:
        """
        (2) Encrypts data using the NEW AES-GCM format, prepending a version
        header and the nonce. All new data will be in this format.
        """
        try:
            nonce = os.urandom(12)  # 96-bit nonce is recommended for AES-GCM
            ciphertext = self.cipher.encrypt(nonce, data, None)
            return b"V2:" + nonce + ciphertext
        except Exception as e:
            logger.error(f"Failed to encrypt data: {e}", exc_info=True)
            raise
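
    # Resulting V2 payload layout (a sketch of what encrypt_data produces):
    #
    #   b"V2:" | 12-byte random nonce | AES-GCM ciphertext (includes 16-byte tag)
    #
    # decrypt_data slices the nonce at bytes [3:15] and treats the remainder as
    # the ciphertext, which is why payloads shorter than 16 bytes are rejected.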

    def decrypt_data(
        self, encrypted_data: bytes, context: Optional[str] = None
    ) -> bytes:
        """Decrypt ``encrypted_data``, handling legacy fallbacks.

        Parameters
        ----------
        encrypted_data:
            The bytes to decrypt.
        context:
            Optional string describing what is being decrypted ("seed", "index", etc.)
            for clearer error messages.
        """
        ctx = f" {context}" if context else ""
        try:
            # Try the new V2 format first
            if encrypted_data.startswith(b"V2:"):
                try:
                    nonce = encrypted_data[3:15]
                    ciphertext = encrypted_data[15:]
                    if len(ciphertext) < 16:
                        logger.error("AES-GCM payload too short")
                        raise InvalidToken("AES-GCM payload too short")
                    return self.cipher.decrypt(nonce, ciphertext, None)
                except InvalidTag as e:
                    logger.debug(
                        "AES-GCM decryption failed: Invalid authentication tag."
                    )
                    try:
                        result = self.fernet.decrypt(encrypted_data[3:])
                        logger.warning(
                            "Legacy-format file had incorrect 'V2:' header; decrypted with Fernet"
                        )
                        return result
                    except InvalidToken:
                        msg = f"Failed to decrypt{ctx}: invalid key or corrupt file"
                        logger.error(msg)
                        raise InvalidToken(msg) from e
            # If it's not V2, it must be the legacy Fernet format
            else:
                logger.warning("Data is in legacy Fernet format. Attempting migration.")
                try:
                    return self.fernet.decrypt(encrypted_data)
                except InvalidToken as e:
                    logger.error(
                        "Legacy Fernet decryption failed. Vault may be corrupt or key is incorrect."
                    )
                    raise e
        except (InvalidToken, InvalidTag) as e:
            if encrypted_data.startswith(b"V2:"):
                # Already determined not to be legacy; re-raise
                raise
            if isinstance(e, InvalidToken) and str(e) == "AES-GCM payload too short":
                raise
            if not self._legacy_migrate_flag:
                raise
            logger.debug(f"Could not decrypt data{ctx}: {e}")
            raise LegacyFormatRequiresMigrationError(context)

    def decrypt_legacy(
        self, encrypted_data: bytes, password: str, context: Optional[str] = None
    ) -> bytes:
        """Decrypt ``encrypted_data`` using legacy password-only key derivation."""
        ctx = f" {context}" if context else ""
        last_exc: Optional[Exception] = None
        for iter_count in [50_000, 100_000]:
            try:
                legacy_key = _derive_legacy_key_from_password(
                    password, iterations=iter_count
                )
                legacy_mgr = EncryptionManager(legacy_key, self.fingerprint_dir)
                legacy_mgr._legacy_migrate_flag = False
                result = legacy_mgr.decrypt_data(encrypted_data, context=context)
                try:  # record iteration count for future runs
                    from .vault import Vault
                    from .config_manager import ConfigManager

                    cfg_mgr = ConfigManager(
                        Vault(self, self.fingerprint_dir), self.fingerprint_dir
                    )
                    cfg_mgr.set_kdf_iterations(iter_count)
                except Exception:  # pragma: no cover - best effort
                    logger.error(
                        "Failed to record PBKDF2 iteration count in config",
                        exc_info=True,
                    )
                logger.warning(
                    "Data decrypted using legacy password-only key derivation."
                )
                return result
            except Exception as e2:  # pragma: no cover - try next iteration
                last_exc = e2

        logger.error(f"Failed legacy decryption attempt: {last_exc}", exc_info=True)
        raise InvalidToken(
            f"Could not decrypt{ctx} with any available method."
        ) from last_exc
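
    # Legacy migration flow (as implemented above): decrypt_data raises
    # LegacyFormatRequiresMigrationError for non-V2 data it cannot read, the
    # caller prompts for the master password, and decrypt_legacy retries with
    # password-only PBKDF2 keys at 50,000 and then 100,000 iterations before
    # giving up. Successful callers then re-save the data in the V2 format.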

    # --- All functions below this point now use the smart `decrypt_data` method ---

    def resolve_relative_path(self, relative_path: Path) -> Path:
        """Resolve ``relative_path`` within ``fingerprint_dir`` and validate it.

        Parameters
        ----------
        relative_path:
            The user-supplied path relative to ``fingerprint_dir``.

        Returns
        -------
        Path
            The normalized absolute path inside ``fingerprint_dir``.

        Raises
        ------
        ValueError
            If the resulting path is absolute or escapes ``fingerprint_dir``.
        """
        candidate = (self.fingerprint_dir / relative_path).resolve()
        if not candidate.is_relative_to(self.fingerprint_dir.resolve()):
            raise ValueError("Invalid path outside fingerprint directory")
        return candidate
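
    # Illustrative behaviour (the second path is a hypothetical example):
    #   resolve_relative_path(Path("seedpass_entries_db.json.enc"))
    #       -> <fingerprint_dir>/seedpass_entries_db.json.enc
    #   resolve_relative_path(Path("../other_profile/secret.enc"))
    #       -> raises ValueError (escapes the fingerprint directory)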

    def encrypt_parent_seed(
        self, parent_seed: str, kdf: Optional[KdfConfig] = None
    ) -> None:
        """Encrypts and saves the parent seed to 'parent_seed.enc'."""
        data = parent_seed.encode("utf-8")
        self.encrypt_and_save_file(data, self.parent_seed_file, kdf=kdf)
        logger.info(f"Parent seed encrypted and saved to '{self.parent_seed_file}'.")

    def decrypt_parent_seed(self) -> str:
        """Decrypts and returns the parent seed, handling migration."""
        with exclusive_lock(self.parent_seed_file) as fh:
            fh.seek(0)
            blob = fh.read()

        kdf, encrypted_data = self._deserialize(blob)
        is_legacy = not encrypted_data.startswith(b"V2:")
        decrypted_data = self.decrypt_data(encrypted_data, context="seed")

        if is_legacy:
            logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.")
            self.encrypt_parent_seed(decrypted_data.decode("utf-8").strip(), kdf=kdf)

        return decrypted_data.decode("utf-8").strip()

    def _serialize(self, kdf: KdfConfig, ciphertext: bytes) -> bytes:
        payload = {"kdf": asdict(kdf), "ct": base64.b64encode(ciphertext).decode()}
        if USE_ORJSON:
            return json_lib.dumps(payload)
        return json_lib.dumps(payload, separators=(",", ":")).encode("utf-8")

    def _deserialize(self, blob: bytes) -> Tuple[KdfConfig, bytes]:
        """Return ``(KdfConfig, ciphertext)`` from serialized *blob*.

        Legacy files stored the raw ciphertext without a JSON wrapper. If
        decoding the wrapper fails, treat ``blob`` as the ciphertext and return
        a default HKDF configuration.
        """
        try:
            if USE_ORJSON:
                obj = json_lib.loads(blob)
            else:
                obj = json_lib.loads(blob.decode("utf-8"))
            kdf = KdfConfig(**obj.get("kdf", {}))
            ct_b64 = obj.get("ct", "")
            ciphertext = base64.b64decode(ct_b64)
            if ciphertext:
                return kdf, ciphertext
        except Exception:  # pragma: no cover - fall back to legacy path
            pass

        # Legacy format: ``blob`` already contains the ciphertext
        return (
            KdfConfig(name="hkdf", version=CURRENT_KDF_VERSION, params={}, salt_b64=""),
            blob,
        )
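
    # On-disk wrapper produced by _serialize (a sketch; the exact KdfConfig
    # fields come from utils.key_derivation):
    #
    #   {"kdf": {"name": ..., "version": ..., "params": {...}, "salt_b64": ...},
    #    "ct": "<base64 of the V2 payload>"}
    #
    # Anything that does not parse as this wrapper is treated by _deserialize
    # as bare legacy ciphertext with a default HKDF configuration.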

    def encrypt_and_save_file(
        self, data: bytes, relative_path: Path, *, kdf: Optional[KdfConfig] = None
    ) -> None:
        if kdf is None:
            kdf = KdfConfig()
        file_path = self.resolve_relative_path(relative_path)
        file_path.parent.mkdir(parents=True, exist_ok=True)

        encrypted_data = self.encrypt_data(data)
        payload = self._serialize(kdf, encrypted_data)

        with exclusive_lock(file_path) as fh:
            fh.seek(0)
            fh.truncate()
            fh.write(payload)
            fh.flush()
            os.fsync(fh.fileno())
        os.chmod(file_path, 0o600)

    def decrypt_file(self, relative_path: Path) -> bytes:
        file_path = self.resolve_relative_path(relative_path)
        with exclusive_lock(file_path) as fh:
            fh.seek(0)
            blob = fh.read()
        _, encrypted_data = self._deserialize(blob)
        return self.decrypt_data(encrypted_data, context=str(relative_path))

    def get_file_kdf(self, relative_path: Path) -> KdfConfig:
        file_path = self.resolve_relative_path(relative_path)
        with exclusive_lock(file_path) as fh:
            fh.seek(0)
            blob = fh.read()
        kdf, _ = self._deserialize(blob)
        return kdf

    def save_json_data(
        self,
        data: dict,
        relative_path: Optional[Path] = None,
        *,
        kdf: Optional[KdfConfig] = None,
    ) -> None:
        if relative_path is None:
            relative_path = Path("seedpass_entries_db.json.enc")
        if USE_ORJSON:
            json_data = json_lib.dumps(data)
        else:
            json_data = json_lib.dumps(data, separators=(",", ":")).encode("utf-8")
        self.encrypt_and_save_file(json_data, relative_path, kdf=kdf)
        logger.debug(f"JSON data encrypted and saved to '{relative_path}'.")

    def load_json_data(
        self, relative_path: Optional[Path] = None, *, return_kdf: bool = False
    ) -> dict | Tuple[dict, KdfConfig]:
        """
        Loads and decrypts JSON data, automatically migrating and re-saving
        it if it is in the legacy format.
        """
        if relative_path is None:
            relative_path = Path("seedpass_entries_db.json.enc")
        file_path = self.resolve_relative_path(relative_path)

        if not file_path.exists():
            empty: dict = {"entries": {}}
            if return_kdf:
                return empty, KdfConfig(
                    name="hkdf", version=CURRENT_KDF_VERSION, params={}, salt_b64=""
                )
            return empty

        with exclusive_lock(file_path) as fh:
            fh.seek(0)
            blob = fh.read()

        kdf, encrypted_data = self._deserialize(blob)
        is_legacy = not encrypted_data.startswith(b"V2:")
        self.last_migration_performed = False

        try:
            decrypted_data = self.decrypt_data(
                encrypted_data, context=str(relative_path)
            )
            if USE_ORJSON:
                data = json_lib.loads(decrypted_data)
            else:
                data = json_lib.loads(decrypted_data.decode("utf-8"))

            # If it was a legacy file, re-save it in the new format now
            if is_legacy and self._legacy_migrate_flag:
                logger.info(f"Migrating and re-saving legacy vault file: {file_path}")
                self.save_json_data(data, relative_path, kdf=kdf)
                self.update_checksum(relative_path)
                self.last_migration_performed = True

            if return_kdf:
                return data, kdf
            return data
        except (InvalidToken, InvalidTag, JSONDecodeError) as e:
            logger.error(
                f"FATAL: Could not decrypt or parse data from {file_path}: {e}",
                exc_info=True,
            )
            raise
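
    # ``last_migration_performed`` is reset to False on every load and flips to
    # True only when a legacy index was actually re-encrypted and re-saved, so
    # callers can tell whether the on-disk format just changed.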

    def get_encrypted_index(self) -> Optional[bytes]:
        relative_path = Path("seedpass_entries_db.json.enc")
        file_path = self.resolve_relative_path(relative_path)
        if not file_path.exists():
            return None
        with exclusive_lock(file_path) as fh:
            fh.seek(0)
            return fh.read()

    def decrypt_and_save_index_from_nostr(
        self,
        encrypted_data: bytes,
        relative_path: Optional[Path] = None,
        *,
        strict: bool = True,
        merge: bool = False,
    ) -> bool:
        """Decrypts data from Nostr and saves it.

        Parameters
        ----------
        encrypted_data:
            The payload downloaded from Nostr.
        relative_path:
            Destination filename under the profile directory.
        strict:
            When ``True`` (default) re-raise any decryption error. When ``False``,
            return ``False`` if decryption fails.
        merge:
            When ``True``, merge the downloaded entries into the existing local
            index instead of replacing it outright.
        """
        if relative_path is None:
            relative_path = Path("seedpass_entries_db.json.enc")

        kdf, ciphertext = self._deserialize(encrypted_data)
        is_legacy = not ciphertext.startswith(b"V2:")
        self.last_migration_performed = False

        def _process(decrypted: bytes) -> dict:
            if USE_ORJSON:
                data = json_lib.loads(decrypted)
            else:
                data = json_lib.loads(decrypted.decode("utf-8"))

            existing_file = self.resolve_relative_path(relative_path)
            if merge and existing_file.exists():
                current = self.load_json_data(relative_path)
                current_entries = current.get("entries", {})
                for idx, entry in data.get("entries", {}).items():
                    cur_ts = current_entries.get(idx, {}).get("modified_ts", 0)
                    new_ts = entry.get("modified_ts", 0)
                    if idx not in current_entries or new_ts >= cur_ts:
                        current_entries[idx] = entry
                current["entries"] = current_entries
                if "schema_version" in data:
                    current["schema_version"] = max(
                        current.get("schema_version", 0), data.get("schema_version", 0)
                    )
                data = current
            return data

        try:
            decrypted_data = self.decrypt_data(ciphertext, context=str(relative_path))
            data = _process(decrypted_data)
            self.save_json_data(data, relative_path, kdf=kdf)
            self.update_checksum(relative_path)
            logger.info("Index file from Nostr was processed and saved successfully.")
            self.last_migration_performed = is_legacy
            return True
        except (InvalidToken, LegacyFormatRequiresMigrationError):
            try:
                password = prompt_existing_password(
                    "Enter your master password for legacy decryption: "
                )
                decrypted_data = self.decrypt_legacy(
                    ciphertext, password, context=str(relative_path)
                )
                data = _process(decrypted_data)
                self.save_json_data(data, relative_path, kdf=kdf)
                self.update_checksum(relative_path)
                logger.warning(
                    "Index decrypted using legacy password-only key derivation."
                )
                print(
                    colored(
                        "Warning: index decrypted with legacy key; it will be re-encrypted.",
                        "yellow",
                    )
                )
                self.last_migration_performed = True
                return True
            except Exception as e2:
                if strict:
                    logger.error(
                        f"Failed legacy decryption attempt: {e2}",
                        exc_info=True,
                    )
                    print(
                        colored(
                            f"Error: Failed to decrypt and save data from Nostr: {e2}",
                            "red",
                        )
                    )
                    raise
                logger.warning(f"Failed to decrypt index from Nostr: {e2}")
                return False
        except Exception as e:  # pragma: no cover - error handling
            if strict:
                logger.error(
                    f"Failed to decrypt and save data from Nostr: {e}",
                    exc_info=True,
                )
                print(
                    colored(
                        f"Error: Failed to decrypt and save data from Nostr: {e}",
                        "red",
                    )
                )
                raise
            logger.warning(f"Failed to decrypt index from Nostr: {e}")
            return False
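
    # Merge semantics in _process above: for each entry id the copy with the
    # newer ``modified_ts`` wins, and the downloaded copy wins ties, while
    # ``schema_version`` is taken as the maximum of both sides. With
    # ``strict=False`` any decryption failure is logged and reported as a
    # ``False`` return instead of an exception.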

    def update_checksum(self, relative_path: Optional[Path] = None) -> None:
        """Updates the checksum file for the specified file."""
        if relative_path is None:
            relative_path = Path("seedpass_entries_db.json.enc")
        file_path = self.resolve_relative_path(relative_path)
        if not file_path.exists():
            return
        try:
            with exclusive_lock(file_path) as fh:
                fh.seek(0)
                encrypted_bytes = fh.read()
            checksum = hashlib.sha256(encrypted_bytes).hexdigest()

            # Build checksum path by stripping both `.json` and `.enc`
            checksum_base = file_path.with_suffix("").with_suffix("")
            checksum_file = checksum_base.parent / f"{checksum_base.name}_checksum.txt"

            # Remove legacy checksum file if present
            legacy_checksum = file_path.parent / f"{file_path.stem}_checksum.txt"
            if legacy_checksum != checksum_file and legacy_checksum.exists():
                try:
                    legacy_checksum.unlink()
                except Exception:
                    logger.warning(
                        f"Could not remove legacy checksum file '{legacy_checksum}'",
                        exc_info=True,
                    )

            with exclusive_lock(checksum_file) as fh:
                fh.seek(0)
                fh.truncate()
                fh.write(checksum.encode("utf-8"))
                fh.flush()
                os.fsync(fh.fileno())
            os.chmod(checksum_file, 0o600)
        except Exception as e:
            logger.error(
                f"Failed to update checksum for '{relative_path}': {e}",
                exc_info=True,
            )
            raise
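
    # For the default index file ``seedpass_entries_db.json.enc`` the checksum
    # written above lands in ``seedpass_entries_db_checksum.txt``, while the
    # older ``seedpass_entries_db.json_checksum.txt`` name is cleaned up if it
    # is still present.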

    def validate_seed(self, seed_phrase: str) -> tuple[bool, Optional[str]]:
        """Validate a BIP-39 mnemonic.

        Returns a tuple of ``(is_valid, error_message)`` where ``error_message``
        is ``None`` when the mnemonic is valid.
        """
        try:
            if Mnemonic("english").check(seed_phrase):
                logger.debug("Seed phrase validated successfully.")
                return True, None
            logger.error("Seed phrase failed BIP-39 validation.")
            return False, "Invalid seed phrase."
        except Exception as e:
            logger.error(f"Error validating seed phrase: {e}", exc_info=True)
            return False, f"Failed to validate seed phrase: {e}"

    def derive_seed_from_mnemonic(self, mnemonic: str, passphrase: str = "") -> bytes:
        try:
            if not isinstance(mnemonic, str):
                if isinstance(mnemonic, list):
                    mnemonic = " ".join(mnemonic)
                else:
                    mnemonic = str(mnemonic)
            if not isinstance(mnemonic, str):
                raise TypeError("Mnemonic must be a string after conversion")

            from bip_utils import Bip39SeedGenerator

            seed = Bip39SeedGenerator(mnemonic).Generate(passphrase)
            logger.debug("Seed derived successfully from mnemonic.")
            return seed
        except Exception as e:
            logger.error(f"Failed to derive seed from mnemonic: {e}", exc_info=True)
            print(colored(f"Error: Failed to derive seed from mnemonic: {e}", "red"))
            raise
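

# A minimal, self-contained usage sketch (illustrative only, not part of the
# library API; it relies solely on the classes defined above and writes to a
# throwaway directory).
if __name__ == "__main__":  # pragma: no cover - illustrative usage sketch
    import tempfile

    with tempfile.TemporaryDirectory() as tmp:
        # Any urlsafe-base64 32-byte key works for both Fernet and AES-GCM here.
        manager = EncryptionManager(Fernet.generate_key(), Path(tmp))

        # Round-trip raw bytes through the V2 (AES-GCM) format.
        blob = manager.encrypt_data(b"hello seedpass")
        assert manager.decrypt_data(blob) == b"hello seedpass"

        # Persist and reload an (empty) entries index, then update its checksum.
        manager.save_json_data({"entries": {}})
        manager.update_checksum()
        print(manager.load_json_data())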