From c10e2380e5f628d38e39c68513592da055cc2bf2 Mon Sep 17 00:00:00 2001 From: thePR0M3TH3AN <53631862+PR0M3TH3AN@users.noreply.github.com> Date: Tue, 1 Jul 2025 21:43:52 -0400 Subject: [PATCH] Add checksum verification for portable backups --- src/password_manager/portable_backup.py | 17 +++++++------ src/tests/test_checksum_utils.py | 9 +++++++ src/tests/test_portable_backup.py | 33 ++++++++++++++++++++++++- src/utils/__init__.py | 9 ++++++- src/utils/checksum.py | 14 ++++++++++- 5 files changed, 71 insertions(+), 11 deletions(-) diff --git a/src/password_manager/portable_backup.py b/src/password_manager/portable_backup.py index d424cf9..b027545 100644 --- a/src/password_manager/portable_backup.py +++ b/src/password_manager/portable_backup.py @@ -5,7 +5,6 @@ from __future__ import annotations import base64 import json -import hashlib import logging import os import time @@ -22,6 +21,7 @@ from utils.key_derivation import ( ) from utils.password_prompt import prompt_existing_password from password_manager.encryption import EncryptionManager +from utils.checksum import json_checksum, canonical_json_dumps logger = logging.getLogger(__name__) @@ -72,10 +72,10 @@ def export_backup( key = _derive_export_key(seed, mode, password) enc_mgr = EncryptionManager(key, vault.fingerprint_dir) - payload_bytes = enc_mgr.encrypt_data( - json.dumps(index_data, indent=4).encode("utf-8") - ) - checksum = hashlib.sha256(payload_bytes).hexdigest() + + canonical = canonical_json_dumps(index_data) + payload_bytes = enc_mgr.encrypt_data(canonical.encode("utf-8")) + checksum = json_checksum(index_data) wrapper = { "format_version": FORMAT_VERSION, @@ -122,9 +122,6 @@ def import_backup( mode = PortableMode(wrapper.get("encryption_mode", PortableMode.SEED_ONLY.value)) payload = base64.b64decode(wrapper["payload"]) - checksum = hashlib.sha256(payload).hexdigest() - if checksum != wrapper.get("checksum"): - raise ValueError("Checksum mismatch") seed = vault.encryption_manager.decrypt_parent_seed() password = None @@ -136,5 +133,9 @@ def import_backup( index_bytes = enc_mgr.decrypt_data(payload) index = json.loads(index_bytes.decode("utf-8")) + checksum = json_checksum(index) + if checksum != wrapper.get("checksum"): + raise ValueError("Checksum mismatch") + backup_manager.create_backup() vault.save_index(index) diff --git a/src/tests/test_checksum_utils.py b/src/tests/test_checksum_utils.py index e30643d..816462a 100644 --- a/src/tests/test_checksum_utils.py +++ b/src/tests/test_checksum_utils.py @@ -1,9 +1,18 @@ import hashlib +import json from pathlib import Path from utils import checksum +def test_json_checksum(): + data = {"b": 1, "a": 2} + expected = hashlib.sha256( + json.dumps(data, sort_keys=True, separators=(",", ":")).encode() + ).hexdigest() + assert checksum.json_checksum(data) == expected + + def test_calculate_checksum(tmp_path): file = tmp_path / "data.txt" content = "hello world" diff --git a/src/tests/test_portable_backup.py b/src/tests/test_portable_backup.py index 6b9c665..ba78dd1 100644 --- a/src/tests/test_portable_backup.py +++ b/src/tests/test_portable_backup.py @@ -57,6 +57,9 @@ def test_round_trip_across_modes(monkeypatch): assert vault.load_index()["pw"] == data["pw"] +from cryptography.fernet import InvalidToken + + def test_corruption_detection(monkeypatch): with TemporaryDirectory() as td: tmp = Path(td) @@ -75,7 +78,7 @@ def test_corruption_detection(monkeypatch): content["payload"] = base64.b64encode(payload).decode() path.write_text(json.dumps(content)) - with pytest.raises(ValueError): + with pytest.raises(InvalidToken): import_backup(vault, backup, path) @@ -115,3 +118,31 @@ def test_import_over_existing(monkeypatch): import_backup(vault, backup, path) loaded = vault.load_index() assert loaded["v"] == 1 + + +def test_checksum_mismatch_detection(monkeypatch): + with TemporaryDirectory() as td: + tmp = Path(td) + vault, backup = setup_vault(tmp) + vault.save_index({"a": 1}) + + monkeypatch.setattr( + "password_manager.portable_backup.prompt_existing_password", + lambda *_a, **_k: PASSWORD, + ) + + path = export_backup(vault, backup, PortableMode.SEED_ONLY) + + wrapper = json.loads(path.read_text()) + payload = base64.b64decode(wrapper["payload"]) + key = derive_index_key(SEED, PASSWORD, EncryptionMode.SEED_ONLY) + enc_mgr = EncryptionManager(key, tmp) + data = json.loads(enc_mgr.decrypt_data(payload).decode()) + data["a"] = 2 + mod_canon = json.dumps(data, sort_keys=True, separators=(",", ":")) + new_payload = enc_mgr.encrypt_data(mod_canon.encode()) + wrapper["payload"] = base64.b64encode(new_payload).decode() + path.write_text(json.dumps(wrapper)) + + with pytest.raises(ValueError): + import_backup(vault, backup, path) diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 6e21714..eb5ba69 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -14,7 +14,12 @@ try: EncryptionMode, DEFAULT_ENCRYPTION_MODE, ) - from .checksum import calculate_checksum, verify_checksum + from .checksum import ( + calculate_checksum, + verify_checksum, + json_checksum, + canonical_json_dumps, + ) from .password_prompt import prompt_for_password if logger.isEnabledFor(logging.DEBUG): @@ -31,6 +36,8 @@ __all__ = [ "DEFAULT_ENCRYPTION_MODE", "calculate_checksum", "verify_checksum", + "json_checksum", + "canonical_json_dumps", "exclusive_lock", "shared_lock", "prompt_for_password", diff --git a/src/utils/checksum.py b/src/utils/checksum.py index 60278f6..983e060 100644 --- a/src/utils/checksum.py +++ b/src/utils/checksum.py @@ -14,8 +14,9 @@ import hashlib import logging import sys import os +import json import traceback -from typing import Optional +from typing import Optional, Any from termcolor import colored @@ -25,6 +26,17 @@ from constants import APP_DIR, SCRIPT_CHECKSUM_FILE logger = logging.getLogger(__name__) +def canonical_json_dumps(data: Any) -> str: + """Serialize ``data`` into a canonical JSON string.""" + return json.dumps(data, sort_keys=True, separators=(",", ":")) + + +def json_checksum(data: Any) -> str: + """Return SHA-256 checksum of canonical JSON serialization of ``data``.""" + canon = canonical_json_dumps(data) + return hashlib.sha256(canon.encode("utf-8")).hexdigest() + + def calculate_checksum(file_path: str) -> Optional[str]: """ Calculates the SHA-256 checksum of the given file.