Merge pull request #120 from PR0M3TH3AN/codex/add-checksum-validation-to-backup-process

Add portable backup checksum verification
This commit is contained in:
thePR0M3TH3AN
2025-07-01 21:45:17 -04:00
committed by GitHub
5 changed files with 71 additions and 11 deletions

View File

@@ -5,7 +5,6 @@ from __future__ import annotations
import base64
import json
import hashlib
import logging
import os
import time
@@ -22,6 +21,7 @@ from utils.key_derivation import (
)
from utils.password_prompt import prompt_existing_password
from password_manager.encryption import EncryptionManager
from utils.checksum import json_checksum, canonical_json_dumps
logger = logging.getLogger(__name__)
@@ -72,10 +72,10 @@ def export_backup(
key = _derive_export_key(seed, mode, password)
enc_mgr = EncryptionManager(key, vault.fingerprint_dir)
payload_bytes = enc_mgr.encrypt_data(
json.dumps(index_data, indent=4).encode("utf-8")
)
checksum = hashlib.sha256(payload_bytes).hexdigest()
canonical = canonical_json_dumps(index_data)
payload_bytes = enc_mgr.encrypt_data(canonical.encode("utf-8"))
checksum = json_checksum(index_data)
wrapper = {
"format_version": FORMAT_VERSION,
@@ -122,9 +122,6 @@ def import_backup(
mode = PortableMode(wrapper.get("encryption_mode", PortableMode.SEED_ONLY.value))
payload = base64.b64decode(wrapper["payload"])
checksum = hashlib.sha256(payload).hexdigest()
if checksum != wrapper.get("checksum"):
raise ValueError("Checksum mismatch")
seed = vault.encryption_manager.decrypt_parent_seed()
password = None
@@ -136,5 +133,9 @@ def import_backup(
index_bytes = enc_mgr.decrypt_data(payload)
index = json.loads(index_bytes.decode("utf-8"))
checksum = json_checksum(index)
if checksum != wrapper.get("checksum"):
raise ValueError("Checksum mismatch")
backup_manager.create_backup()
vault.save_index(index)

View File

@@ -1,9 +1,18 @@
import hashlib
import json
from pathlib import Path
from utils import checksum
def test_json_checksum():
data = {"b": 1, "a": 2}
expected = hashlib.sha256(
json.dumps(data, sort_keys=True, separators=(",", ":")).encode()
).hexdigest()
assert checksum.json_checksum(data) == expected
def test_calculate_checksum(tmp_path):
file = tmp_path / "data.txt"
content = "hello world"

View File

@@ -57,6 +57,9 @@ def test_round_trip_across_modes(monkeypatch):
assert vault.load_index()["pw"] == data["pw"]
from cryptography.fernet import InvalidToken
def test_corruption_detection(monkeypatch):
with TemporaryDirectory() as td:
tmp = Path(td)
@@ -75,7 +78,7 @@ def test_corruption_detection(monkeypatch):
content["payload"] = base64.b64encode(payload).decode()
path.write_text(json.dumps(content))
with pytest.raises(ValueError):
with pytest.raises(InvalidToken):
import_backup(vault, backup, path)
@@ -115,3 +118,31 @@ def test_import_over_existing(monkeypatch):
import_backup(vault, backup, path)
loaded = vault.load_index()
assert loaded["v"] == 1
def test_checksum_mismatch_detection(monkeypatch):
with TemporaryDirectory() as td:
tmp = Path(td)
vault, backup = setup_vault(tmp)
vault.save_index({"a": 1})
monkeypatch.setattr(
"password_manager.portable_backup.prompt_existing_password",
lambda *_a, **_k: PASSWORD,
)
path = export_backup(vault, backup, PortableMode.SEED_ONLY)
wrapper = json.loads(path.read_text())
payload = base64.b64decode(wrapper["payload"])
key = derive_index_key(SEED, PASSWORD, EncryptionMode.SEED_ONLY)
enc_mgr = EncryptionManager(key, tmp)
data = json.loads(enc_mgr.decrypt_data(payload).decode())
data["a"] = 2
mod_canon = json.dumps(data, sort_keys=True, separators=(",", ":"))
new_payload = enc_mgr.encrypt_data(mod_canon.encode())
wrapper["payload"] = base64.b64encode(new_payload).decode()
path.write_text(json.dumps(wrapper))
with pytest.raises(ValueError):
import_backup(vault, backup, path)

View File

@@ -14,7 +14,12 @@ try:
EncryptionMode,
DEFAULT_ENCRYPTION_MODE,
)
from .checksum import calculate_checksum, verify_checksum
from .checksum import (
calculate_checksum,
verify_checksum,
json_checksum,
canonical_json_dumps,
)
from .password_prompt import prompt_for_password
if logger.isEnabledFor(logging.DEBUG):
@@ -31,6 +36,8 @@ __all__ = [
"DEFAULT_ENCRYPTION_MODE",
"calculate_checksum",
"verify_checksum",
"json_checksum",
"canonical_json_dumps",
"exclusive_lock",
"shared_lock",
"prompt_for_password",

View File

@@ -14,8 +14,9 @@ import hashlib
import logging
import sys
import os
import json
import traceback
from typing import Optional
from typing import Optional, Any
from termcolor import colored
@@ -25,6 +26,17 @@ from constants import APP_DIR, SCRIPT_CHECKSUM_FILE
logger = logging.getLogger(__name__)
def canonical_json_dumps(data: Any) -> str:
"""Serialize ``data`` into a canonical JSON string."""
return json.dumps(data, sort_keys=True, separators=(",", ":"))
def json_checksum(data: Any) -> str:
"""Return SHA-256 checksum of canonical JSON serialization of ``data``."""
canon = canonical_json_dumps(data)
return hashlib.sha256(canon.encode("utf-8")).hexdigest()
def calculate_checksum(file_path: str) -> Optional[str]:
"""
Calculates the SHA-256 checksum of the given file.