Allow prompt callback to generate_step (#1133)

* allow prompt callback and use in cache_prompt

* nit

* comments

* bump version
This commit is contained in:
Awni Hannun
2024-12-03 16:17:14 -08:00
committed by GitHub
parent 0ca162cfb2
commit 1963df8565
5 changed files with 48 additions and 48 deletions

View File

@@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase):
def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
results = list(generate_step(prompt, model, max_tokens=4))
toks, all_logits = zip(*results)
prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
for tok, logits in generate_step(
prompt, model, prompt_cache=prompt_cache, max_tokens=2
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1
for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
for tok, logits in generate_step(
mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
):
i += 1
self.assertEqual(tok, toks[i])