Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00.
Allow prompt callback to generate_step
(#1133)
Changes: allow a prompt callback and use it in cache_prompt; nit; comments; bump version.
This commit is contained in:
@@ -121,21 +121,20 @@ class TestPromptCache(unittest.TestCase):
|
||||
def test_cache_with_generate(self):
    """Verify generation with a prompt cache matches uncached generation.

    Strategy: run ``generate_step`` once without a cache to collect 4
    reference (token, logits) pairs, then replay the same prompt through a
    fresh prompt cache in two stages — 2 tokens from the full prompt, then
    1 more continuing from the last generated token — asserting the cached
    run reproduces the reference tokens (and logits for the first stage).
    """
    model, tokenizer = load(HF_MODEL_PATH)
    prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

    # Reference run: no prompt cache, 4 tokens.
    results = list(generate_step(prompt, model, max_tokens=4))
    toks, all_logits = zip(*results)

    prompt_cache = make_prompt_cache(model)

    i = 0
    # Stage 1: first 2 tokens generated with the cache must match the
    # reference run token-for-token and logit-for-logit.
    for tok, logits in generate_step(
        prompt, model, prompt_cache=prompt_cache, max_tokens=2
    ):
        self.assertEqual(tok, toks[i])
        self.assertTrue(mx.allclose(logits, all_logits[i]))
        i += 1

    # Stage 2: continue from the cached state, feeding only the last
    # generated token; the next token must again match the reference.
    for tok, logits in generate_step(
        mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
    ):
        i += 1
        self.assertEqual(tok, toks[i])
|
Reference in New Issue
Block a user