Merge pull request #120 from PR0M3TH3AN/codex/add-checksum-validation-to-backup-process

Add portable backup checksum verification
This commit is contained in:
thePR0M3TH3AN
2025-07-01 21:45:17 -04:00
committed by GitHub
5 changed files with 71 additions and 11 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import base64 import base64
import json import json
import hashlib
import logging import logging
import os import os
import time import time
@@ -22,6 +21,7 @@ from utils.key_derivation import (
) )
from utils.password_prompt import prompt_existing_password from utils.password_prompt import prompt_existing_password
from password_manager.encryption import EncryptionManager from password_manager.encryption import EncryptionManager
from utils.checksum import json_checksum, canonical_json_dumps
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -72,10 +72,10 @@ def export_backup(
key = _derive_export_key(seed, mode, password) key = _derive_export_key(seed, mode, password)
enc_mgr = EncryptionManager(key, vault.fingerprint_dir) enc_mgr = EncryptionManager(key, vault.fingerprint_dir)
payload_bytes = enc_mgr.encrypt_data(
json.dumps(index_data, indent=4).encode("utf-8") canonical = canonical_json_dumps(index_data)
) payload_bytes = enc_mgr.encrypt_data(canonical.encode("utf-8"))
checksum = hashlib.sha256(payload_bytes).hexdigest() checksum = json_checksum(index_data)
wrapper = { wrapper = {
"format_version": FORMAT_VERSION, "format_version": FORMAT_VERSION,
@@ -122,9 +122,6 @@ def import_backup(
mode = PortableMode(wrapper.get("encryption_mode", PortableMode.SEED_ONLY.value)) mode = PortableMode(wrapper.get("encryption_mode", PortableMode.SEED_ONLY.value))
payload = base64.b64decode(wrapper["payload"]) payload = base64.b64decode(wrapper["payload"])
checksum = hashlib.sha256(payload).hexdigest()
if checksum != wrapper.get("checksum"):
raise ValueError("Checksum mismatch")
seed = vault.encryption_manager.decrypt_parent_seed() seed = vault.encryption_manager.decrypt_parent_seed()
password = None password = None
@@ -136,5 +133,9 @@ def import_backup(
index_bytes = enc_mgr.decrypt_data(payload) index_bytes = enc_mgr.decrypt_data(payload)
index = json.loads(index_bytes.decode("utf-8")) index = json.loads(index_bytes.decode("utf-8"))
checksum = json_checksum(index)
if checksum != wrapper.get("checksum"):
raise ValueError("Checksum mismatch")
backup_manager.create_backup() backup_manager.create_backup()
vault.save_index(index) vault.save_index(index)

View File

@@ -1,9 +1,18 @@
import hashlib import hashlib
import json
from pathlib import Path from pathlib import Path
from utils import checksum from utils import checksum
def test_json_checksum():
    """``json_checksum`` must hash the compact, key-sorted JSON encoding."""
    data = {"b": 1, "a": 2}
    expected = hashlib.sha256(
        json.dumps(data, sort_keys=True, separators=(",", ":")).encode()
    ).hexdigest()
    assert checksum.json_checksum(data) == expected
    # Same mapping, different insertion order: the digest must not change,
    # otherwise a checksum could differ after a JSON load/dump round trip.
    assert checksum.json_checksum({"a": 2, "b": 1}) == expected
def test_calculate_checksum(tmp_path): def test_calculate_checksum(tmp_path):
file = tmp_path / "data.txt" file = tmp_path / "data.txt"
content = "hello world" content = "hello world"

View File

@@ -57,6 +57,9 @@ def test_round_trip_across_modes(monkeypatch):
assert vault.load_index()["pw"] == data["pw"] assert vault.load_index()["pw"] == data["pw"]
from cryptography.fernet import InvalidToken
def test_corruption_detection(monkeypatch): def test_corruption_detection(monkeypatch):
with TemporaryDirectory() as td: with TemporaryDirectory() as td:
tmp = Path(td) tmp = Path(td)
@@ -75,7 +78,7 @@ def test_corruption_detection(monkeypatch):
content["payload"] = base64.b64encode(payload).decode() content["payload"] = base64.b64encode(payload).decode()
path.write_text(json.dumps(content)) path.write_text(json.dumps(content))
with pytest.raises(ValueError): with pytest.raises(InvalidToken):
import_backup(vault, backup, path) import_backup(vault, backup, path)
@@ -115,3 +118,31 @@ def test_import_over_existing(monkeypatch):
import_backup(vault, backup, path) import_backup(vault, backup, path)
loaded = vault.load_index() loaded = vault.load_index()
assert loaded["v"] == 1 assert loaded["v"] == 1
def test_checksum_mismatch_detection(monkeypatch):
    """Import must reject a payload that decrypts but no longer matches the checksum."""
    with TemporaryDirectory() as td:
        tmp = Path(td)
        vault, backup = setup_vault(tmp)
        vault.save_index({"a": 1})

        monkeypatch.setattr(
            "password_manager.portable_backup.prompt_existing_password",
            lambda *_a, **_k: PASSWORD,
        )

        path = export_backup(vault, backup, PortableMode.SEED_ONLY)

        # Decrypt the exported payload, change one value, and re-encrypt it
        # with the same manager, leaving the wrapper's stored checksum
        # untouched so it describes the old data.
        wrapper = json.loads(path.read_text())
        payload = base64.b64decode(wrapper["payload"])
        key = derive_index_key(SEED, PASSWORD, EncryptionMode.SEED_ONLY)
        enc_mgr = EncryptionManager(key, tmp)
        data = json.loads(enc_mgr.decrypt_data(payload).decode())
        data["a"] = 2
        mod_canon = json.dumps(data, sort_keys=True, separators=(",", ":"))
        new_payload = enc_mgr.encrypt_data(mod_canon.encode())
        wrapper["payload"] = base64.b64encode(new_payload).decode()
        path.write_text(json.dumps(wrapper))

        # Match on the message: JSON and base64 decoding errors are also
        # ValueError subclasses and must not satisfy this test by accident.
        with pytest.raises(ValueError, match="Checksum mismatch"):
            import_backup(vault, backup, path)

View File

@@ -14,7 +14,12 @@ try:
EncryptionMode, EncryptionMode,
DEFAULT_ENCRYPTION_MODE, DEFAULT_ENCRYPTION_MODE,
) )
from .checksum import calculate_checksum, verify_checksum from .checksum import (
calculate_checksum,
verify_checksum,
json_checksum,
canonical_json_dumps,
)
from .password_prompt import prompt_for_password from .password_prompt import prompt_for_password
if logger.isEnabledFor(logging.DEBUG): if logger.isEnabledFor(logging.DEBUG):
@@ -31,6 +36,8 @@ __all__ = [
"DEFAULT_ENCRYPTION_MODE", "DEFAULT_ENCRYPTION_MODE",
"calculate_checksum", "calculate_checksum",
"verify_checksum", "verify_checksum",
"json_checksum",
"canonical_json_dumps",
"exclusive_lock", "exclusive_lock",
"shared_lock", "shared_lock",
"prompt_for_password", "prompt_for_password",

View File

@@ -14,8 +14,9 @@ import hashlib
import logging import logging
import sys import sys
import os import os
import json
import traceback import traceback
from typing import Optional from typing import Optional, Any
from termcolor import colored from termcolor import colored
@@ -25,6 +26,17 @@ from constants import APP_DIR, SCRIPT_CHECKSUM_FILE
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def canonical_json_dumps(data: Any) -> str:
    """Serialize ``data`` into a canonical JSON string.

    Object keys are emitted in sorted order and the compact ``(",", ":")``
    separators are used, so two mappings with the same content produce the
    same text regardless of insertion order.

    Args:
        data: Any value that ``json.dumps`` can serialize.

    Returns:
        The compact, key-sorted JSON text.
    """
    return json.dumps(data, sort_keys=True, separators=(",", ":"))
def json_checksum(data: Any) -> str:
    """Return SHA-256 checksum of canonical JSON serialization of ``data``.

    The data is first rendered with ``canonical_json_dumps`` and the
    resulting text is hashed as UTF-8 bytes.

    Args:
        data: Any value that ``canonical_json_dumps`` can serialize.

    Returns:
        The hexadecimal SHA-256 digest of the canonical JSON text.
    """
    canon = canonical_json_dumps(data)
    return hashlib.sha256(canon.encode("utf-8")).hexdigest()
def calculate_checksum(file_path: str) -> Optional[str]: def calculate_checksum(file_path: str) -> Optional[str]:
""" """
Calculates the SHA-256 checksum of the given file. Calculates the SHA-256 checksum of the given file.