update generate_step callsites

Lllvvuu
2024-12-27 01:51:28 -08:00
parent 3082db0143
commit a28ca03e04
3 changed files with 26 additions and 15 deletions
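
Every callsite follows the same migration: generate_step is now fed a prompt with a leading batch dimension (prompt[None], or mx.array([[tok]]) for a single token) and its per-step outputs keep that batch axis, so callers unwrap each step with tok.item() and logits.squeeze(0). A minimal sketch of the new calling convention; the import paths, HF_MODEL_PATH value, and exact generate_step signature are assumed from the surrounding tests rather than shown in this diff:

# Sketch only: assumes generate_step now expects a (1, seq_len) prompt and
# yields (token, logits) pairs that keep a leading batch axis.
import mlx.core as mx
from mlx_lm.models.cache import make_prompt_cache  # import path assumed
from mlx_lm.utils import generate_step, load       # import path assumed

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"  # placeholder model id

model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

# Old callsite: generate_step(prompt, model, ...) yielded unbatched outputs.
# New callsite: add the batch dimension going in, strip it coming out.
prompt_cache = make_prompt_cache(model)
for tok, logits in generate_step(
    prompt[None], model, prompt_cache=prompt_cache, max_tokens=2
):
    tok, logits = tok.item(), logits.squeeze(0)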


@@ -121,21 +121,24 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = list(generate_step(prompt, model, max_tokens=4))
+        results = list(generate_step(prompt[None], model, max_tokens=4))
+        results = [(t.item(), l.squeeze(0)) for t, l in results]
         toks, all_logits = zip(*results)

         prompt_cache = make_prompt_cache(model)
         i = 0
         for tok, logits in generate_step(
-            prompt, model, prompt_cache=prompt_cache, max_tokens=2
+            prompt[None], model, prompt_cache=prompt_cache, max_tokens=2
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1

         for tok, logits in generate_step(
-            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
+            mx.array([[toks[i]]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
@@ -205,14 +208,14 @@ class TestPromptCache(unittest.TestCase):
         prompt_cache = make_prompt_cache(model)

         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
-        last_tok = mx.array([last_tok])
+        last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))

         # Generate two more tokens
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)

         # To get back to the cache just after processing the prompt,
         # trim by 3 tokens
@@ -220,9 +223,10 @@ class TestPromptCache(unittest.TestCase):

         # Generate the same thing again
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        second_toks, second_all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        second_toks, second_all_logits = zip(*results)
         self.assertEqual(toks, second_toks)
         self.assertTrue(
             all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
@@ -278,14 +282,16 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_to_quantized(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = zip(range(4), generate_step(prompt[None], model))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)

         prompt_cache = make_prompt_cache(model)
         i = 0
         for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+            range(2), generate_step(prompt[None], model, prompt_cache=prompt_cache)
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
@@ -294,8 +300,9 @@ class TestPromptCache(unittest.TestCase):

         for _, (tok, logits) in zip(
             range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+            generate_step(mx.array([[toks[i]]]), model, prompt_cache=prompt_cache),
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))