Handle EntryType objects when loading

This commit is contained in:
thePR0M3TH3AN
2025-07-15 12:47:01 -04:00
parent 0bfc641815
commit 5a3b80b4f6
2 changed files with 56 additions and 24 deletions

View File

@@ -1837,6 +1837,13 @@ class PasswordManager:
self.is_dirty = True
self.last_update = time.time()
def _entry_type_str(self, entry: dict) -> str:
"""Return the entry type as a lowercase string."""
entry_type = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(entry_type, EntryType):
entry_type = entry_type.value
return str(entry_type).lower()
def _entry_actions_menu(self, index: int, entry: dict) -> None:
"""Provide actions for a retrieved entry."""
while True:
@@ -1849,9 +1856,7 @@ class PasswordManager:
child_fingerprint=child_fp,
)
archived = entry.get("archived", entry.get("blacklisted", False))
entry_type = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(entry_type, str):
entry_type = entry_type.lower()
entry_type = self._entry_type_str(entry)
print(colored("\n[+] Entry Actions:", "green"))
if archived:
print(colored("U. Unarchive", "cyan"))
@@ -1934,9 +1939,7 @@ class PasswordManager:
def _entry_edit_menu(self, index: int, entry: dict) -> None:
"""Sub-menu for editing common entry fields."""
entry_type = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(entry_type, str):
entry_type = entry_type.lower()
entry_type = self._entry_type_str(entry)
while True:
fp, parent_fp, child_fp = self.header_fingerprint_args
clear_header_with_notification(
@@ -1996,9 +1999,7 @@ class PasswordManager:
def _entry_qr_menu(self, index: int, entry: dict) -> None:
"""Display QR codes for the given ``entry``."""
entry_type = entry.get("type", entry.get("kind"))
if isinstance(entry_type, str):
entry_type = entry_type.lower()
entry_type = self._entry_type_str(entry)
try:
if entry_type in {EntryType.SEED.value, EntryType.MANAGED_ACCOUNT.value}:
@@ -2071,9 +2072,7 @@ class PasswordManager:
self._suppress_entry_actions_menu = False
entry_type = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(entry_type, str):
entry_type = entry_type.lower()
entry_type = self._entry_type_str(entry)
if entry_type == EntryType.TOTP.value:
label = entry.get("label", "")
@@ -2533,9 +2532,7 @@ class PasswordManager:
if not entry:
return
entry_type = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(entry_type, str):
entry_type = entry_type.lower()
entry_type = self._entry_type_str(entry)
if entry_type == EntryType.TOTP.value:
label = entry.get("label", "")
@@ -2928,10 +2925,7 @@ class PasswordManager:
if not entry:
return
etype = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(etype, EntryType):
etype = etype.value
etype = str(etype).lower()
etype = self._entry_type_str(entry)
print(color_text(f"Index: {index}", "index"))
if etype == EntryType.TOTP.value:
print(color_text(f" Label: {entry.get('label', '')}", "index"))
@@ -3286,7 +3280,9 @@ class PasswordManager:
entries = data.get("entries", {})
totp_list: list[tuple[str, int, int, bool]] = []
for idx_str, entry in entries.items():
if entry.get("type") == EntryType.TOTP.value and not entry.get(
if self._entry_type_str(
entry
) == EntryType.TOTP.value and not entry.get(
"archived", entry.get("blacklisted", False)
):
label = entry.get("label", "")
@@ -3582,7 +3578,7 @@ class PasswordManager:
totp_entries = []
for entry in entries.values():
if entry.get("type") == EntryType.TOTP.value:
if self._entry_type_str(entry) == EntryType.TOTP.value:
label = entry.get("label", "")
period = int(entry.get("period", 30))
digits = int(entry.get("digits", 6))
@@ -3878,9 +3874,7 @@ class PasswordManager:
entries = data.get("entries", {})
counts: dict[str, int] = {etype.value: 0 for etype in EntryType}
for entry in entries.values():
etype = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if isinstance(etype, str):
etype = etype.lower()
etype = self._entry_type_str(entry)
counts[etype] = counts.get(etype, 0) + 1
stats["entries"] = counts
stats["total_entries"] = len(entries)

View File

@@ -13,6 +13,7 @@ sys.path.append(str(Path(__file__).resolve().parents[1]))
from password_manager.entry_management import EntryManager
from password_manager.backup import BackupManager
from password_manager.manager import PasswordManager, EncryptionMode
from password_manager.entry_types import EntryType
from password_manager.config_manager import ConfigManager
@@ -356,3 +357,40 @@ def test_show_entry_details_sensitive(monkeypatch, capsys, entry_type):
if entry_type in {"ssh", "pgp"}:
assert extra in out
assert called == [True]
@pytest.mark.parametrize(
"entry_type", [EntryType.PASSWORD, EntryType.TOTP, EntryType.KEY_VALUE]
)
def test_show_entry_details_with_enum_type(monkeypatch, capsys, entry_type):
"""Entries storing an EntryType enum should display correctly."""
with TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
pm, entry_mgr = _setup_manager(tmp_path)
if entry_type == EntryType.PASSWORD:
idx = entry_mgr.add_entry("example.com", 8)
expect = "example.com"
elif entry_type == EntryType.TOTP:
entry_mgr.add_totp("Example", TEST_SEED)
idx = 0
monkeypatch.setattr(
pm.entry_manager, "get_totp_code", lambda *a, **k: "123456"
)
monkeypatch.setattr(
pm.entry_manager, "get_totp_time_remaining", lambda *a, **k: 1
)
expect = "Label: Example"
else: # KEY_VALUE
idx = entry_mgr.add_key_value("API", "abc")
expect = "API"
data = entry_mgr._load_index(force_reload=True)
data["entries"][str(idx)]["type"] = entry_type
entry_mgr._save_index(data)
called = _detail_common(monkeypatch, pm)
pm.show_entry_details_by_index(idx)
out = capsys.readouterr().out
assert expect in out
assert called == [True]