flush writes for concurrency safety

This commit is contained in:
thePR0M3TH3AN
2025-07-13 18:29:47 -04:00
parent 6fe4b86a19
commit 0e7c3e8a84

View File

@@ -8,7 +8,9 @@ try:
JSONDecodeError = orjson.JSONDecodeError JSONDecodeError = orjson.JSONDecodeError
USE_ORJSON = True USE_ORJSON = True
except Exception: # pragma: no cover - fallback for environments without orjson except (
Exception
): # pragma: no cover - fallback for environments without orjson
import json as json_lib import json as json_lib
from json import JSONDecodeError from json import JSONDecodeError
@@ -58,7 +60,9 @@ class EncryptionManager:
self.key = base64.urlsafe_b64decode(self.key_b64) self.key = base64.urlsafe_b64decode(self.key_b64)
self.cipher = AESGCM(self.key) self.cipher = AESGCM(self.key)
logger.debug(f"EncryptionManager initialized for {self.fingerprint_dir}") logger.debug(
f"EncryptionManager initialized for {self.fingerprint_dir}"
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to initialize ciphers with provided encryption key: {e}", f"Failed to initialize ciphers with provided encryption key: {e}",
@@ -91,12 +95,16 @@ class EncryptionManager:
ciphertext = encrypted_data[15:] ciphertext = encrypted_data[15:]
return self.cipher.decrypt(nonce, ciphertext, None) return self.cipher.decrypt(nonce, ciphertext, None)
except InvalidTag as e: except InvalidTag as e:
logger.error("AES-GCM decryption failed: Invalid authentication tag.") logger.error(
"AES-GCM decryption failed: Invalid authentication tag."
)
raise InvalidToken("AES-GCM decryption failed.") from e raise InvalidToken("AES-GCM decryption failed.") from e
# If it's not V2, it must be the legacy Fernet format # If it's not V2, it must be the legacy Fernet format
else: else:
logger.warning("Data is in legacy Fernet format. Attempting migration.") logger.warning(
"Data is in legacy Fernet format. Attempting migration."
)
try: try:
return self.fernet.decrypt(encrypted_data) return self.fernet.decrypt(encrypted_data)
except InvalidToken as e: except InvalidToken as e:
@@ -118,7 +126,9 @@ class EncryptionManager:
fh.truncate() fh.truncate()
fh.write(encrypted_data) fh.write(encrypted_data)
os.chmod(self.parent_seed_file, 0o600) os.chmod(self.parent_seed_file, 0o600)
logger.info(f"Parent seed encrypted and saved to '{self.parent_seed_file}'.") logger.info(
f"Parent seed encrypted and saved to '{self.parent_seed_file}'."
)
def decrypt_parent_seed(self) -> str: def decrypt_parent_seed(self) -> str:
"""Decrypts and returns the parent seed, handling migration.""" """Decrypts and returns the parent seed, handling migration."""
@@ -130,7 +140,9 @@ class EncryptionManager:
decrypted_data = self.decrypt_data(encrypted_data) decrypted_data = self.decrypt_data(encrypted_data)
if is_legacy: if is_legacy:
logger.info("Parent seed was in legacy format. Re-encrypting to V2 format.") logger.info(
"Parent seed was in legacy format. Re-encrypting to V2 format."
)
self.encrypt_parent_seed(decrypted_data.decode("utf-8").strip()) self.encrypt_parent_seed(decrypted_data.decode("utf-8").strip())
return decrypted_data.decode("utf-8").strip() return decrypted_data.decode("utf-8").strip()
@@ -143,6 +155,8 @@ class EncryptionManager:
fh.seek(0) fh.seek(0)
fh.truncate() fh.truncate()
fh.write(encrypted_data) fh.write(encrypted_data)
fh.flush()
os.fsync(fh.fileno())
os.chmod(file_path, 0o600) os.chmod(file_path, 0o600)
def decrypt_file(self, relative_path: Path) -> bytes: def decrypt_file(self, relative_path: Path) -> bytes:
@@ -152,13 +166,17 @@ class EncryptionManager:
encrypted_data = fh.read() encrypted_data = fh.read()
return self.decrypt_data(encrypted_data) return self.decrypt_data(encrypted_data)
def save_json_data(self, data: dict, relative_path: Optional[Path] = None) -> None: def save_json_data(
self, data: dict, relative_path: Optional[Path] = None
) -> None:
if relative_path is None: if relative_path is None:
relative_path = Path("seedpass_entries_db.json.enc") relative_path = Path("seedpass_entries_db.json.enc")
if USE_ORJSON: if USE_ORJSON:
json_data = json_lib.dumps(data) json_data = json_lib.dumps(data)
else: else:
json_data = json_lib.dumps(data, separators=(",", ":")).encode("utf-8") json_data = json_lib.dumps(data, separators=(",", ":")).encode(
"utf-8"
)
self.encrypt_and_save_file(json_data, relative_path) self.encrypt_and_save_file(json_data, relative_path)
logger.debug(f"JSON data encrypted and saved to '{relative_path}'.") logger.debug(f"JSON data encrypted and saved to '{relative_path}'.")
@@ -189,7 +207,9 @@ class EncryptionManager:
# If it was a legacy file, re-save it in the new format now # If it was a legacy file, re-save it in the new format now
if is_legacy: if is_legacy:
logger.info(f"Migrating and re-saving legacy vault file: {file_path}") logger.info(
f"Migrating and re-saving legacy vault file: {file_path}"
)
self.save_json_data(data, relative_path) self.save_json_data(data, relative_path)
self.update_checksum(relative_path) self.update_checksum(relative_path)
@@ -224,17 +244,25 @@ class EncryptionManager:
data = json_lib.loads(decrypted_data) data = json_lib.loads(decrypted_data)
else: else:
data = json_lib.loads(decrypted_data.decode("utf-8")) data = json_lib.loads(decrypted_data.decode("utf-8"))
self.save_json_data(data, relative_path) # This always saves in V2 format self.save_json_data(
data, relative_path
) # This always saves in V2 format
self.update_checksum(relative_path) self.update_checksum(relative_path)
logger.info("Index file from Nostr was processed and saved successfully.") logger.info(
print(colored("Index file updated from Nostr successfully.", "green")) "Index file from Nostr was processed and saved successfully."
)
print(
colored("Index file updated from Nostr successfully.", "green")
)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to decrypt and save data from Nostr: {e}", exc_info=True f"Failed to decrypt and save data from Nostr: {e}",
exc_info=True,
) )
print( print(
colored( colored(
f"Error: Failed to decrypt and save data from Nostr: {e}", "red" f"Error: Failed to decrypt and save data from Nostr: {e}",
"red",
) )
) )
raise raise
@@ -258,10 +286,13 @@ class EncryptionManager:
fh.seek(0) fh.seek(0)
fh.truncate() fh.truncate()
fh.write(checksum.encode("utf-8")) fh.write(checksum.encode("utf-8"))
fh.flush()
os.fsync(fh.fileno())
os.chmod(checksum_file, 0o600) os.chmod(checksum_file, 0o600)
except Exception as e: except Exception as e:
logger.error( logger.error(
f"Failed to update checksum for '{relative_path}': {e}", exc_info=True f"Failed to update checksum for '{relative_path}': {e}",
exc_info=True,
) )
raise raise
@@ -272,17 +303,24 @@ class EncryptionManager:
if len(words) != 12: if len(words) != 12:
logger.error("Seed phrase does not contain exactly 12 words.") logger.error("Seed phrase does not contain exactly 12 words.")
print( print(
colored("Error: Seed phrase must contain exactly 12 words.", "red") colored(
"Error: Seed phrase must contain exactly 12 words.",
"red",
)
) )
return False return False
logger.debug("Seed phrase validated successfully.") logger.debug("Seed phrase validated successfully.")
return True return True
except Exception as e: except Exception as e:
logging.error(f"Error validating seed phrase: {e}", exc_info=True) logging.error(f"Error validating seed phrase: {e}", exc_info=True)
print(colored(f"Error: Failed to validate seed phrase: {e}", "red")) print(
colored(f"Error: Failed to validate seed phrase: {e}", "red")
)
return False return False
def derive_seed_from_mnemonic(self, mnemonic: str, passphrase: str = "") -> bytes: def derive_seed_from_mnemonic(
self, mnemonic: str, passphrase: str = ""
) -> bytes:
try: try:
if not isinstance(mnemonic, str): if not isinstance(mnemonic, str):
if isinstance(mnemonic, list): if isinstance(mnemonic, list):
@@ -290,13 +328,21 @@ class EncryptionManager:
else: else:
mnemonic = str(mnemonic) mnemonic = str(mnemonic)
if not isinstance(mnemonic, str): if not isinstance(mnemonic, str):
raise TypeError("Mnemonic must be a string after conversion") raise TypeError(
"Mnemonic must be a string after conversion"
)
from bip_utils import Bip39SeedGenerator from bip_utils import Bip39SeedGenerator
seed = Bip39SeedGenerator(mnemonic).Generate(passphrase) seed = Bip39SeedGenerator(mnemonic).Generate(passphrase)
logger.debug("Seed derived successfully from mnemonic.") logger.debug("Seed derived successfully from mnemonic.")
return seed return seed
except Exception as e: except Exception as e:
logger.error(f"Failed to derive seed from mnemonic: {e}", exc_info=True) logger.error(
print(colored(f"Error: Failed to derive seed from mnemonic: {e}", "red")) f"Failed to derive seed from mnemonic: {e}", exc_info=True
)
print(
colored(
f"Error: Failed to derive seed from mnemonic: {e}", "red"
)
)
raise raise