diff --git a/src/tests/test_seed_prompt.py b/src/tests/test_seed_prompt.py index 3711cd4..163b7ff 100644 --- a/src/tests/test_seed_prompt.py +++ b/src/tests/test_seed_prompt.py @@ -1,4 +1,5 @@ import types +import pytest from utils import seed_prompt @@ -46,6 +47,37 @@ def test_masked_input_windows_space(monkeypatch, capsys): assert out.count("*") == 4 +def test_masked_input_posix_ctrl_c(monkeypatch): + seq = iter(["\x03"]) + monkeypatch.setattr(seed_prompt.sys.stdin, "read", lambda n=1: next(seq)) + monkeypatch.setattr(seed_prompt.sys.stdin, "fileno", lambda: 0) + + calls: list[tuple[str, int]] = [] + fake_termios = types.SimpleNamespace( + tcgetattr=lambda fd: "old", + tcsetattr=lambda fd, *_: calls.append(("tcsetattr", fd)), + TCSADRAIN=1, + ) + fake_tty = types.SimpleNamespace(setraw=lambda fd: calls.append(("setraw", fd))) + monkeypatch.setattr(seed_prompt, "termios", fake_termios) + monkeypatch.setattr(seed_prompt, "tty", fake_tty) + monkeypatch.setattr(seed_prompt.sys, "platform", "linux", raising=False) + + with pytest.raises(KeyboardInterrupt): + seed_prompt.masked_input("Enter: ") + assert calls == [("setraw", 0), ("tcsetattr", 0)] + + +def test_masked_input_windows_ctrl_c(monkeypatch): + seq = iter(["\x03"]) + fake_msvcrt = types.SimpleNamespace(getwch=lambda: next(seq)) + monkeypatch.setattr(seed_prompt, "msvcrt", fake_msvcrt) + monkeypatch.setattr(seed_prompt.sys, "platform", "win32", raising=False) + + with pytest.raises(KeyboardInterrupt): + seed_prompt.masked_input("Password: ") + + def test_prompt_seed_words_valid(monkeypatch): from mnemonic import Mnemonic diff --git a/src/utils/seed_prompt.py b/src/utils/seed_prompt.py index 7ff0130..7511488 100644 --- a/src/utils/seed_prompt.py +++ b/src/utils/seed_prompt.py @@ -58,6 +58,8 @@ def _masked_input_windows(prompt: str) -> str: buffer: list[str] = [] while True: ch = msvcrt.getwch() + if ch == "\x03": + raise KeyboardInterrupt if ch in ("\r", "\n"): sys.stdout.write("\n") return "".join(buffer) @@ -85,6 +87,8 @@ def _masked_input_posix(prompt: str) -> str: tty.setraw(fd) while True: ch = sys.stdin.read(1) + if ch == "\x03": + raise KeyboardInterrupt if ch in ("\r", "\n"): sys.stdout.write("\n") return "".join(buffer) @@ -105,6 +109,8 @@ def masked_input(prompt: str) -> str: func = _masked_input_windows if sys.platform == "win32" else _masked_input_posix try: return func(prompt) + except KeyboardInterrupt: + raise except Exception: # pragma: no cover - fallback when TTY operations fail return input(prompt)