From a2a3236248e453d1a403a3a800803ebd1a99fcc8 Mon Sep 17 00:00:00 2001 From: thePR0M3TH3AN <53631862+PR0M3TH3AN@users.noreply.github.com> Date: Sun, 29 Jun 2025 22:58:43 -0400 Subject: [PATCH] Add shared file lock and concurrency test --- src/tests/test_file_locking.py | 42 ++++++++++++++++++++++++++++++++++ src/utils/__init__.py | 3 ++- src/utils/file_lock.py | 22 ++++++++++++++++++ 3 files changed, 66 insertions(+), 1 deletion(-) create mode 100644 src/tests/test_file_locking.py diff --git a/src/tests/test_file_locking.py b/src/tests/test_file_locking.py new file mode 100644 index 0000000..1fc0e18 --- /dev/null +++ b/src/tests/test_file_locking.py @@ -0,0 +1,42 @@ +import threading +from pathlib import Path + +from utils.file_lock import exclusive_lock, shared_lock + + +def _writer(path: Path, content: str, exceptions: list[str]) -> None: + try: + with exclusive_lock(path): + path.write_text(content) + except Exception as e: # pragma: no cover - just capture + exceptions.append(repr(e)) + + +def _reader(path: Path, results: list[str], exceptions: list[str]) -> None: + try: + with shared_lock(path): + results.append(path.read_text()) + except Exception as e: # pragma: no cover + exceptions.append(repr(e)) + + +def test_concurrent_shared_and_exclusive_lock(tmp_path: Path) -> None: + file_path = tmp_path / "data.txt" + file_path.write_text("init") + + exceptions: list[str] = [] + reads: list[str] = [] + for i in range(5): + writer = threading.Thread( + target=_writer, args=(file_path, f"value{i}", exceptions) + ) + reader = threading.Thread(target=_reader, args=(file_path, reads, exceptions)) + + writer.start() + reader.start() + writer.join() + reader.join() + + assert not exceptions + assert file_path.read_text() == "value4" + assert len(reads) == 5 diff --git a/src/utils/__init__.py b/src/utils/__init__.py index 3329ebf..e12cbff 100644 --- a/src/utils/__init__.py +++ b/src/utils/__init__.py @@ -4,7 +4,7 @@ import logging import traceback try: - from .file_lock import exclusive_lock + from .file_lock import exclusive_lock, shared_lock from .key_derivation import derive_key_from_password, derive_key_from_parent_seed from .checksum import calculate_checksum, verify_checksum from .password_prompt import prompt_for_password @@ -20,5 +20,6 @@ __all__ = [ "calculate_checksum", "verify_checksum", "exclusive_lock", + "shared_lock", "prompt_for_password", ] diff --git a/src/utils/file_lock.py b/src/utils/file_lock.py index 9c71b62..4d674f2 100644 --- a/src/utils/file_lock.py +++ b/src/utils/file_lock.py @@ -22,3 +22,25 @@ def exclusive_lock( lock = portalocker.Lock(str(path), mode="a+b", timeout=timeout) with lock as fh: yield fh + + +@contextmanager +def shared_lock( + path: Path, timeout: Optional[float] = None +) -> Generator[None, None, None]: + """Context manager that locks *path* with a shared lock. + + The function opens the file in binary read/write mode and obtains a + shared lock using ``portalocker``. If ``timeout`` is provided, acquiring + the lock will wait for at most that many seconds before raising + ``portalocker.exceptions.LockException``. + """ + path = Path(path) + path.parent.mkdir(parents=True, exist_ok=True) + path.touch(exist_ok=True) + lock = portalocker.Lock( + str(path), mode="r+b", timeout=timeout, flags=portalocker.LockFlags.SHARED + ) + with lock as fh: + fh.seek(0) + yield fh