diff --git a/src/tests/test_seed_prompt.py b/src/tests/test_seed_prompt.py index c030b69..700c03f 100644 --- a/src/tests/test_seed_prompt.py +++ b/src/tests/test_seed_prompt.py @@ -37,9 +37,9 @@ def test_prompt_seed_words_valid(monkeypatch): phrase = m.generate(strength=128) words = phrase.split() - inputs = iter(words + ["y"] * len(words)) - monkeypatch.setattr(seed_prompt, "masked_input", lambda *_: next(inputs)) - monkeypatch.setattr("builtins.input", lambda *_: next(inputs)) + word_iter = iter(words) + monkeypatch.setattr(seed_prompt, "masked_input", lambda *_: next(word_iter)) + monkeypatch.setattr("builtins.input", lambda *_: "y") result = seed_prompt.prompt_seed_words(len(words)) assert result == phrase @@ -52,9 +52,9 @@ def test_prompt_seed_words_invalid_word(monkeypatch): phrase = m.generate(strength=128) words = phrase.split() # Insert an invalid word for the first entry then the correct one - inputs = iter(["invalid"] + [words[0]] + words[1:] + ["y"] * len(words)) + inputs = iter(["invalid"] + words) monkeypatch.setattr(seed_prompt, "masked_input", lambda *_: next(inputs)) - monkeypatch.setattr("builtins.input", lambda *_: next(inputs)) + monkeypatch.setattr("builtins.input", lambda *_: "y") result = seed_prompt.prompt_seed_words(len(words)) assert result == phrase