mirror of
https://github.com/PR0M3TH3AN/SeedPass.git
synced 2025-09-09 15:58:48 +00:00
Add atomic write utility and tests
This commit is contained in:
@@ -32,6 +32,7 @@ from utils import (
|
|||||||
pause,
|
pause,
|
||||||
clear_header_with_notification,
|
clear_header_with_notification,
|
||||||
)
|
)
|
||||||
|
from utils.atomic_write import atomic_write
|
||||||
import queue
|
import queue
|
||||||
from local_bip85.bip85 import Bip85Error
|
from local_bip85.bip85 import Bip85Error
|
||||||
|
|
||||||
@@ -667,8 +668,7 @@ def handle_set_additional_backup_location(pm: PasswordManager) -> None:
|
|||||||
path = Path(value).expanduser()
|
path = Path(value).expanduser()
|
||||||
path.mkdir(parents=True, exist_ok=True)
|
path.mkdir(parents=True, exist_ok=True)
|
||||||
test_file = path / ".seedpass_write_test"
|
test_file = path / ".seedpass_write_test"
|
||||||
with open(test_file, "w") as f:
|
atomic_write(test_file, lambda f: f.write("test"))
|
||||||
f.write("test")
|
|
||||||
test_file.unlink()
|
test_file.unlink()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(colored(f"Path not writable: {e}", "red"))
|
print(colored(f"Path not writable: {e}", "red"))
|
||||||
|
@@ -37,6 +37,7 @@ from .entry_types import EntryType
|
|||||||
from .totp import TotpManager
|
from .totp import TotpManager
|
||||||
from utils.fingerprint import generate_fingerprint
|
from utils.fingerprint import generate_fingerprint
|
||||||
from utils.checksum import canonical_json_dumps
|
from utils.checksum import canonical_json_dumps
|
||||||
|
from utils.atomic_write import atomic_write
|
||||||
from utils.key_validation import (
|
from utils.key_validation import (
|
||||||
validate_totp_secret,
|
validate_totp_secret,
|
||||||
validate_ssh_key_pair,
|
validate_ssh_key_pair,
|
||||||
@@ -1312,8 +1313,7 @@ class EntryManager:
|
|||||||
# The checksum file path already includes the fingerprint directory
|
# The checksum file path already includes the fingerprint directory
|
||||||
checksum_path = self.checksum_file
|
checksum_path = self.checksum_file
|
||||||
|
|
||||||
with open(checksum_path, "w") as f:
|
atomic_write(checksum_path, lambda f: f.write(checksum))
|
||||||
f.write(checksum)
|
|
||||||
|
|
||||||
logger.debug(f"Checksum updated and written to '{checksum_path}'.")
|
logger.debug(f"Checksum updated and written to '{checksum_path}'.")
|
||||||
print(colored(f"[+] Checksum updated successfully.", "green"))
|
print(colored(f"[+] Checksum updated successfully.", "green"))
|
||||||
|
@@ -66,6 +66,7 @@ from utils.terminal_utils import (
|
|||||||
clear_header_with_notification,
|
clear_header_with_notification,
|
||||||
)
|
)
|
||||||
from utils.fingerprint import generate_fingerprint
|
from utils.fingerprint import generate_fingerprint
|
||||||
|
from utils.atomic_write import atomic_write
|
||||||
from constants import MIN_HEALTHY_RELAYS
|
from constants import MIN_HEALTHY_RELAYS
|
||||||
from .migrations import LATEST_VERSION
|
from .migrations import LATEST_VERSION
|
||||||
|
|
||||||
@@ -4377,8 +4378,11 @@ class PasswordManager:
|
|||||||
else:
|
else:
|
||||||
# Fallback to legacy file method if config_manager unavailable
|
# Fallback to legacy file method if config_manager unavailable
|
||||||
hashed_password_file = self.fingerprint_dir / "hashed_password.enc"
|
hashed_password_file = self.fingerprint_dir / "hashed_password.enc"
|
||||||
with open(hashed_password_file, "wb") as f:
|
atomic_write(
|
||||||
f.write(hashed.encode())
|
hashed_password_file,
|
||||||
|
lambda f: f.write(hashed.encode()),
|
||||||
|
mode="wb",
|
||||||
|
)
|
||||||
os.chmod(hashed_password_file, 0o600)
|
os.chmod(hashed_password_file, 0o600)
|
||||||
logging.info("User password hashed and stored successfully.")
|
logging.info("User password hashed and stored successfully.")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
@@ -4389,8 +4393,11 @@ class PasswordManager:
|
|||||||
self.config_manager.set_password_hash(hashed)
|
self.config_manager.set_password_hash(hashed)
|
||||||
else:
|
else:
|
||||||
hashed_password_file = self.fingerprint_dir / "hashed_password.enc"
|
hashed_password_file = self.fingerprint_dir / "hashed_password.enc"
|
||||||
with open(hashed_password_file, "wb") as f:
|
atomic_write(
|
||||||
f.write(hashed.encode())
|
hashed_password_file,
|
||||||
|
lambda f: f.write(hashed.encode()),
|
||||||
|
mode="wb",
|
||||||
|
)
|
||||||
os.chmod(hashed_password_file, 0o600)
|
os.chmod(hashed_password_file, 0o600)
|
||||||
logging.info(
|
logging.info(
|
||||||
"User password hashed and stored successfully (using alternative method)."
|
"User password hashed and stored successfully (using alternative method)."
|
||||||
|
30
src/tests/test_atomic_write.py
Normal file
30
src/tests/test_atomic_write.py
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
import json
|
||||||
|
from multiprocessing import Process
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from utils.atomic_write import atomic_write
|
||||||
|
|
||||||
|
|
||||||
|
def _writer(path: Path, content: dict, loops: int) -> None:
|
||||||
|
for _ in range(loops):
|
||||||
|
atomic_write(path, lambda f: json.dump(content, f), mode="w")
|
||||||
|
|
||||||
|
|
||||||
|
def test_atomic_write_concurrent(tmp_path: Path) -> None:
|
||||||
|
"""Concurrent writers should not leave partial files."""
|
||||||
|
|
||||||
|
file_path = tmp_path / "data.json"
|
||||||
|
contents = [{"proc": i} for i in range(5)]
|
||||||
|
|
||||||
|
procs = [
|
||||||
|
Process(target=_writer, args=(file_path, content, 50)) for content in contents
|
||||||
|
]
|
||||||
|
|
||||||
|
for p in procs:
|
||||||
|
p.start()
|
||||||
|
for p in procs:
|
||||||
|
p.join()
|
||||||
|
|
||||||
|
final_text = file_path.read_text()
|
||||||
|
final_obj = json.loads(final_text)
|
||||||
|
assert final_obj in contents
|
@@ -35,6 +35,7 @@ try:
|
|||||||
clear_and_print_fingerprint,
|
clear_and_print_fingerprint,
|
||||||
clear_header_with_notification,
|
clear_header_with_notification,
|
||||||
)
|
)
|
||||||
|
from .atomic_write import atomic_write
|
||||||
|
|
||||||
if logger.isEnabledFor(logging.DEBUG):
|
if logger.isEnabledFor(logging.DEBUG):
|
||||||
logger.info("Modules imported successfully.")
|
logger.info("Modules imported successfully.")
|
||||||
@@ -68,4 +69,5 @@ __all__ = [
|
|||||||
"clear_and_print_fingerprint",
|
"clear_and_print_fingerprint",
|
||||||
"clear_header_with_notification",
|
"clear_header_with_notification",
|
||||||
"pause",
|
"pause",
|
||||||
|
"atomic_write",
|
||||||
]
|
]
|
||||||
|
62
src/utils/atomic_write.py
Normal file
62
src/utils/atomic_write.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Utility helpers for performing atomic file writes.
|
||||||
|
|
||||||
|
This module provides a small helper function :func:`atomic_write` which
|
||||||
|
implements a simple pattern for writing files atomically. Data is written to a
|
||||||
|
temporary file in the same directory, flushed and synced to disk, and then
|
||||||
|
``os.replace`` is used to atomically move the temporary file into place.
|
||||||
|
|
||||||
|
The function accepts a callable ``write_func`` that receives the temporary file
|
||||||
|
object. This keeps the helper flexible enough to support both text and binary
|
||||||
|
writes and allows callers to perform complex serialisation steps (e.g. JSON
|
||||||
|
dumping) without exposing a partially written file to other processes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import tempfile
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Any, IO
|
||||||
|
|
||||||
|
|
||||||
|
def atomic_write(
|
||||||
|
path: str | Path,
|
||||||
|
write_func: Callable[[IO[Any]], None],
|
||||||
|
*,
|
||||||
|
mode: str = "w",
|
||||||
|
**open_kwargs: Any,
|
||||||
|
) -> None:
|
||||||
|
"""Write to ``path`` atomically using ``write_func``.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
path:
|
||||||
|
Destination file path.
|
||||||
|
write_func:
|
||||||
|
Callable that receives an open file object and performs the actual
|
||||||
|
write. The callable should not close the file.
|
||||||
|
mode:
|
||||||
|
File mode used when opening the temporary file. Defaults to ``"w"``.
|
||||||
|
**open_kwargs:
|
||||||
|
Additional keyword arguments passed to :func:`os.fdopen`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
dest = Path(path)
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
fd, tmp_path = tempfile.mkstemp(dir=str(dest.parent))
|
||||||
|
try:
|
||||||
|
with os.fdopen(fd, mode, **open_kwargs) as tmp_file:
|
||||||
|
write_func(tmp_file)
|
||||||
|
tmp_file.flush()
|
||||||
|
os.fsync(tmp_file.fileno())
|
||||||
|
os.replace(tmp_path, dest)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
except FileNotFoundError:
|
||||||
|
pass
|
||||||
|
raise
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["atomic_write"]
|
@@ -21,6 +21,7 @@ from typing import Optional, Any
|
|||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
|
|
||||||
from constants import APP_DIR, SCRIPT_CHECKSUM_FILE
|
from constants import APP_DIR, SCRIPT_CHECKSUM_FILE
|
||||||
|
from utils.atomic_write import atomic_write
|
||||||
|
|
||||||
# Instantiate the logger
|
# Instantiate the logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -121,8 +122,7 @@ def update_checksum(content: str, checksum_file_path: str) -> bool:
|
|||||||
hasher = hashlib.sha256()
|
hasher = hashlib.sha256()
|
||||||
hasher.update(content.encode("utf-8"))
|
hasher.update(content.encode("utf-8"))
|
||||||
new_checksum = hasher.hexdigest()
|
new_checksum = hasher.hexdigest()
|
||||||
with open(checksum_file_path, "w") as f:
|
atomic_write(checksum_file_path, lambda f: f.write(new_checksum))
|
||||||
f.write(new_checksum)
|
|
||||||
logging.debug(f"Updated checksum for '{checksum_file_path}' to: {new_checksum}")
|
logging.debug(f"Updated checksum for '{checksum_file_path}' to: {new_checksum}")
|
||||||
return True
|
return True
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -179,8 +179,7 @@ def initialize_checksum(file_path: str, checksum_file_path: str) -> bool:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(checksum_file_path, "w") as f:
|
atomic_write(checksum_file_path, lambda f: f.write(checksum))
|
||||||
f.write(checksum)
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Initialized checksum file '{checksum_file_path}' with checksum: {checksum}"
|
f"Initialized checksum file '{checksum_file_path}' with checksum: {checksum}"
|
||||||
)
|
)
|
||||||
@@ -206,8 +205,7 @@ def update_checksum_file(file_path: str, checksum_file_path: str) -> bool:
|
|||||||
if checksum is None:
|
if checksum is None:
|
||||||
return False
|
return False
|
||||||
try:
|
try:
|
||||||
with open(checksum_file_path, "w") as f:
|
atomic_write(checksum_file_path, lambda f: f.write(checksum))
|
||||||
f.write(checksum)
|
|
||||||
logging.debug(
|
logging.debug(
|
||||||
f"Updated checksum for '{file_path}' to '{checksum}' at '{checksum_file_path}'."
|
f"Updated checksum for '{file_path}' to '{checksum}' at '{checksum_file_path}'."
|
||||||
)
|
)
|
||||||
|
@@ -9,6 +9,7 @@ from typing import List, Optional
|
|||||||
|
|
||||||
import shutil # Ensure shutil is imported if used within the class
|
import shutil # Ensure shutil is imported if used within the class
|
||||||
|
|
||||||
|
from utils.atomic_write import atomic_write
|
||||||
from utils.fingerprint import generate_fingerprint
|
from utils.fingerprint import generate_fingerprint
|
||||||
|
|
||||||
# Instantiate the logger
|
# Instantiate the logger
|
||||||
@@ -92,16 +93,15 @@ class FingerprintManager:
|
|||||||
Saves the current list of fingerprints to the fingerprints.json file.
|
Saves the current list of fingerprints to the fingerprints.json file.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with open(self.fingerprints_file, "w") as f:
|
data = {
|
||||||
json.dump(
|
"fingerprints": self.fingerprints,
|
||||||
{
|
"last_used": self.current_fingerprint,
|
||||||
"fingerprints": self.fingerprints,
|
"names": self.names,
|
||||||
"last_used": self.current_fingerprint,
|
}
|
||||||
"names": self.names,
|
atomic_write(
|
||||||
},
|
self.fingerprints_file,
|
||||||
f,
|
lambda f: json.dump(data, f, indent=4),
|
||||||
indent=4,
|
)
|
||||||
)
|
|
||||||
logger.debug(
|
logger.debug(
|
||||||
f"Fingerprints saved: {self.fingerprints} (last used: {self.current_fingerprint})"
|
f"Fingerprints saved: {self.fingerprints} (last used: {self.current_fingerprint})"
|
||||||
)
|
)
|
||||||
|
Reference in New Issue
Block a user