Extend entry search filtering

This commit is contained in:
thePR0M3TH3AN
2025-07-18 14:54:10 -04:00
parent 5eab7f879c
commit b0ba723bdd
10 changed files with 90 additions and 86 deletions

View File

@@ -1,5 +1,5 @@
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional, List
import json import json
import typer import typer
@@ -135,10 +135,20 @@ def entry_list(
@entry_app.command("search") @entry_app.command("search")
def entry_search(ctx: typer.Context, query: str) -> None: def entry_search(
ctx: typer.Context,
query: str,
kind: List[str] = typer.Option(
None,
"--kind",
"-k",
help="Filter by entry kinds (can be repeated)",
),
) -> None:
"""Search entries.""" """Search entries."""
service = _get_entry_service(ctx) service = _get_entry_service(ctx)
results = service.search_entries(query) kinds = list(kind) if kind else None
results = service.search_entries(query, kinds=kinds)
if not results: if not results:
typer.echo("No matching entries found") typer.echo("No matching entries found")
return return

View File

@@ -220,9 +220,21 @@ class EntryService:
include_archived=include_archived, include_archived=include_archived,
) )
def search_entries(self, query: str): def search_entries(
self, query: str, kinds: list[str] | None = None
) -> list[tuple[int, str, str | None, str | None, bool]]:
"""Search entries optionally filtering by ``kinds``.
Parameters
----------
query:
Search string to match against entry metadata.
kinds:
Optional list of entry kinds to restrict the search.
"""
with self._lock: with self._lock:
return self._manager.entry_manager.search_entries(query) return self._manager.entry_manager.search_entries(query, kinds=kinds)
def retrieve_entry(self, entry_id: int): def retrieve_entry(self, entry_id: int):
with self._lock: with self._lock:

View File

@@ -1045,9 +1045,10 @@ class EntryManager:
return [] return []
def search_entries( def search_entries(
self, query: str self, query: str, kinds: List[str] | None = None
) -> List[Tuple[int, str, Optional[str], Optional[str], bool]]: ) -> List[Tuple[int, str, Optional[str], Optional[str], bool]]:
"""Return entries matching the query across common fields.""" """Return entries matching ``query`` across whitelisted metadata fields."""
data = self._load_index() data = self._load_index()
entries_data = data.get("entries", {}) entries_data = data.get("entries", {})
@@ -1059,74 +1060,33 @@ class EntryManager:
for idx, entry in sorted(entries_data.items(), key=lambda x: int(x[0])): for idx, entry in sorted(entries_data.items(), key=lambda x: int(x[0])):
etype = entry.get("type", entry.get("kind", EntryType.PASSWORD.value)) etype = entry.get("type", entry.get("kind", EntryType.PASSWORD.value))
if kinds is not None and etype not in kinds:
continue
label = entry.get("label", entry.get("website", "")) label = entry.get("label", entry.get("website", ""))
notes = entry.get("notes", "") username = (
entry.get("username", "") if etype == EntryType.PASSWORD.value else None
)
url = entry.get("url", "") if etype == EntryType.PASSWORD.value else None
tags = entry.get("tags", []) tags = entry.get("tags", [])
archived = entry.get("archived", entry.get("blacklisted", False))
label_match = query_lower in label.lower() label_match = query_lower in label.lower()
notes_match = query_lower in notes.lower() username_match = bool(username) and query_lower in username.lower()
url_match = bool(url) and query_lower in url.lower()
tags_match = any(query_lower in str(t).lower() for t in tags) tags_match = any(query_lower in str(t).lower() for t in tags)
if etype == EntryType.PASSWORD.value: if label_match or username_match or url_match or tags_match:
username = entry.get("username", "") results.append(
url = entry.get("url", "") (
custom_fields = entry.get("custom_fields", []) int(idx),
custom_match = any( label,
query_lower in str(cf.get("label", "")).lower() username if username is not None else None,
or query_lower in str(cf.get("value", "")).lower() url if url is not None else None,
for cf in custom_fields archived,
)
) )
if (
label_match
or query_lower in username.lower()
or query_lower in url.lower()
or notes_match
or custom_match
or tags_match
):
results.append(
(
int(idx),
label,
username,
url,
entry.get("archived", entry.get("blacklisted", False)),
)
)
elif etype in (EntryType.KEY_VALUE.value, EntryType.MANAGED_ACCOUNT.value):
value_field = str(entry.get("value", ""))
custom_fields = entry.get("custom_fields", [])
custom_match = any(
query_lower in str(cf.get("label", "")).lower()
or query_lower in str(cf.get("value", "")).lower()
for cf in custom_fields
)
if (
label_match
or query_lower in value_field.lower()
or notes_match
or custom_match
or tags_match
):
results.append(
(
int(idx),
label,
None,
None,
entry.get("archived", entry.get("blacklisted", False)),
)
)
else:
if label_match or notes_match or tags_match:
results.append(
(
int(idx),
label,
None,
None,
entry.get("archived", entry.get("blacklisted", False)),
)
)
return results return results

View File

@@ -32,8 +32,8 @@ def test_cli_entry_add_search_sync(monkeypatch):
calls["add"] = (label, length, username, url) calls["add"] = (label, length, username, url)
return 1 return 1
def search_entries(q): def search_entries(q, kinds=None):
calls["search"] = q calls["search"] = (q, kinds)
return [(1, "Label", None, None, False)] return [(1, "Label", None, None, False)]
def sync_vault(): def sync_vault():
@@ -57,10 +57,12 @@ def test_cli_entry_add_search_sync(monkeypatch):
assert calls.get("sync") is True assert calls.get("sync") is True
# entry search # entry search
result = runner.invoke(app, ["entry", "search", "lab"]) result = runner.invoke(
app, ["entry", "search", "lab", "--kind", "password", "--kind", "totp"]
)
assert result.exit_code == 0 assert result.exit_code == 0
assert "Label" in result.stdout assert "Label" in result.stdout
assert calls["search"] == "lab" assert calls["search"] == ("lab", ["password", "totp"])
# nostr sync # nostr sync
result = runner.invoke(app, ["nostr", "sync"]) result = runner.invoke(app, ["nostr", "sync"])

View File

@@ -17,7 +17,7 @@ class DummyPM:
list_entries=lambda sort_by="index", filter_kind=None, include_archived=False: [ list_entries=lambda sort_by="index", filter_kind=None, include_archived=False: [
(1, "Label", "user", "url", False) (1, "Label", "user", "url", False)
], ],
search_entries=lambda q: [(1, "GitHub", "user", "", False)], search_entries=lambda q, kinds=None: [(1, "GitHub", "user", "", False)],
retrieve_entry=lambda idx: {"type": EntryType.PASSWORD.value, "length": 8}, retrieve_entry=lambda idx: {"type": EntryType.PASSWORD.value, "length": 8},
get_totp_code=lambda idx, seed: "123456", get_totp_code=lambda idx, seed: "123456",
add_entry=lambda label, length, username, url: 1, add_entry=lambda label, length, username, url: 1,

View File

@@ -25,8 +25,8 @@ def test_entry_service_add_entry_and_search():
called["add"] = (label, length, username, url) called["add"] = (label, length, username, url)
return 5 return 5
def search_entries(q): def search_entries(q, kinds=None):
called["search"] = q called["search"] = (q, kinds)
return [(5, "Example", username, url, False)] return [(5, "Example", username, url, False)]
def sync_vault(): def sync_vault():
@@ -46,9 +46,9 @@ def test_entry_service_add_entry_and_search():
assert called["add"] == ("Example", 12, username, url) assert called["add"] == ("Example", 12, username, url)
assert called.get("sync") is True assert called.get("sync") is True
results = service.search_entries("ex") results = service.search_entries("ex", kinds=["password"])
assert results == [(5, "Example", username, url, False)] assert results == [(5, "Example", username, url, False)]
assert called["search"] == "ex" assert called["search"] == ("ex", ["password"])
def test_sync_service_sync(): def test_sync_service_sync():

View File

@@ -22,7 +22,7 @@ class FakeEntries:
def list_entries(self): def list_entries(self):
return [] return []
def search_entries(self, query): def search_entries(self, query, kinds=None):
return [] return []
def add_entry(self, label, length, username=None, url=None): def add_entry(self, label, length, username=None, url=None):

View File

@@ -41,4 +41,4 @@ def test_add_and_modify_key_value():
assert updated["value"] == "def456" assert updated["value"] == "def456"
results = em.search_entries("def456") results = em.search_entries("def456")
assert results == [(idx, "API", None, None, False)] assert results == []

View File

@@ -9,6 +9,7 @@ sys.path.append(str(Path(__file__).resolve().parents[1]))
from seedpass.core.entry_management import EntryManager from seedpass.core.entry_management import EntryManager
from seedpass.core.backup import BackupManager from seedpass.core.backup import BackupManager
from seedpass.core.config_manager import ConfigManager from seedpass.core.config_manager import ConfigManager
from seedpass.core.entry_types import EntryType
def setup_entry_manager(tmp_path: Path) -> EntryManager: def setup_entry_manager(tmp_path: Path) -> EntryManager:
@@ -64,11 +65,12 @@ def test_search_by_notes_and_totp():
idx_totp = entry_mgr.search_entries("GH")[0][0] idx_totp = entry_mgr.search_entries("GH")[0][0]
entry_mgr.modify_entry(idx_totp, notes="otp note") entry_mgr.modify_entry(idx_totp, notes="otp note")
# notes are no longer searchable
res_notes = entry_mgr.search_entries("secret") res_notes = entry_mgr.search_entries("secret")
assert res_notes == [(idx_pw, "Site", "", "", False)] assert res_notes == []
res_totp = entry_mgr.search_entries("otp") res_totp = entry_mgr.search_entries("otp")
assert res_totp == [(idx_totp, "GH", None, None, False)] assert res_totp == []
def test_search_by_custom_field(): def test_search_by_custom_field():
@@ -83,7 +85,7 @@ def test_search_by_custom_field():
idx = entry_mgr.add_entry("Example", 8, custom_fields=custom) idx = entry_mgr.add_entry("Example", 8, custom_fields=custom)
result = entry_mgr.search_entries("secret123") result = entry_mgr.search_entries("secret123")
assert result == [(idx, "Example", "", "", False)] assert result == []
def test_search_key_value_value(): def test_search_key_value_value():
@@ -94,7 +96,7 @@ def test_search_key_value_value():
idx = entry_mgr.add_key_value("API", "token123") idx = entry_mgr.add_key_value("API", "token123")
result = entry_mgr.search_entries("token123") result = entry_mgr.search_entries("token123")
assert result == [(idx, "API", None, None, False)] assert result == []
def test_search_no_results(): def test_search_no_results():
@@ -128,3 +130,21 @@ def test_search_by_tag_totp():
result = entry_mgr.search_entries("mfa") result = entry_mgr.search_entries("mfa")
assert result == [(idx, "OTPAccount", None, None, False)] assert result == [(idx, "OTPAccount", None, None, False)]
def test_search_with_kind_filter():
with TemporaryDirectory() as tmpdir:
tmp_path = Path(tmpdir)
entry_mgr = setup_entry_manager(tmp_path)
idx_pw = entry_mgr.add_entry("Site", 8)
entry_mgr.add_totp("OTP", TEST_SEED)
idx_totp = entry_mgr.search_entries("OTP")[0][0]
all_results = entry_mgr.search_entries(
"", kinds=[EntryType.PASSWORD.value, EntryType.TOTP.value]
)
assert {r[0] for r in all_results} == {idx_pw, idx_totp}
only_pw = entry_mgr.search_entries("", kinds=[EntryType.PASSWORD.value])
assert only_pw == [(idx_pw, "Site", "", "", False)]

View File

@@ -34,7 +34,7 @@ def test_entry_list(monkeypatch):
def test_entry_search(monkeypatch): def test_entry_search(monkeypatch):
pm = SimpleNamespace( pm = SimpleNamespace(
entry_manager=SimpleNamespace( entry_manager=SimpleNamespace(
search_entries=lambda q: [(1, "L", None, None, False)] search_entries=lambda q, kinds=None: [(1, "L", None, None, False)]
), ),
select_fingerprint=lambda fp: None, select_fingerprint=lambda fp: None,
) )
@@ -45,7 +45,7 @@ def test_entry_search(monkeypatch):
def test_entry_get_password(monkeypatch): def test_entry_get_password(monkeypatch):
def search(q): def search(q, kinds=None):
return [(2, "Example", "", "", False)] return [(2, "Example", "", "", False)]
entry = {"type": EntryType.PASSWORD.value, "length": 8} entry = {"type": EntryType.PASSWORD.value, "length": 8}