diff --git a/llms/mlx_lm/cache_prompt.py b/llms/mlx_lm/cache_prompt.py
index 9d7d1603..4f88061e 100644
--- a/llms/mlx_lm/cache_prompt.py
+++ b/llms/mlx_lm/cache_prompt.py
@@ -132,7 +132,7 @@ def main():
         prompt = args.prompt
 
     cache = make_prompt_cache(model, args.max_kv_size)
-    y = mx.array(tokenizer.encode(prompt))
+    y = mx.array(tokenizer.encode(prompt))[None]
 
     # Process the prompt
     start = time.time()
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index ec52e283..b4f7728d 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -296,7 +296,7 @@ def generate_step(
 
     def _step(y):
         with mx.stream(generation_stream):
-            if y.ndims == 1:
+            if y.ndim == 1:
                 y = mx.expand_dims(y, axis=-1)
             logits = model(
                 y,
@@ -390,12 +390,16 @@ def stream_generate(
             prompt if isinstance(prompt, list) else tokenizer.encode(prompt)
         )
 
+    if prompt.ndim == 1:
+        prompt = prompt[None]
+
     detokenizer = tokenizer.detokenizer
 
     with wired_limit(model, [generation_stream]):
        detokenizer.reset()
         tic = time.perf_counter()
         for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)):
+            token, logprobs = token.item(), logprobs.squeeze(0)
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
diff --git a/llms/tests/test_prompt_cache.py b/llms/tests/test_prompt_cache.py
index de5694d5..6acab5a7 100644
--- a/llms/tests/test_prompt_cache.py
+++ b/llms/tests/test_prompt_cache.py
@@ -121,21 +121,24 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_with_generate(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = list(generate_step(prompt, model, max_tokens=4))
+        results = list(generate_step(prompt[None], model, max_tokens=4))
+        results = [(t.item(), l.squeeze(0)) for t, l in results]
         toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
         for tok, logits in generate_step(
-            prompt, model, prompt_cache=prompt_cache, max_tokens=2
+            prompt[None], model, prompt_cache=prompt_cache, max_tokens=2
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
 
         for tok, logits in generate_step(
-            mx.array([toks[i]]), model, prompt_cache=prompt_cache, max_tokens=1
+            mx.array([[toks[i]]]), model, prompt_cache=prompt_cache, max_tokens=1
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
@@ -205,14 +208,14 @@ class TestPromptCache(unittest.TestCase):
         prompt_cache = make_prompt_cache(model)
 
         # Generate one token so we process the full prompt
-        last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
-        last_tok = mx.array([last_tok])
+        last_tok, _ = next(generate_step(prompt[None], model, prompt_cache=prompt_cache))
 
         # Generate two more tokens
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)
 
         # To get back to the cache just after processing the prompt,
         # trim by 3 tokens
@@ -220,9 +223,10 @@ class TestPromptCache(unittest.TestCase):
 
         # Generate the same thing again
         results = zip(
-            range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
+            range(2), generate_step(last_tok[None], model, prompt_cache=prompt_cache)
         )
-        second_toks, second_all_logits = zip(*(r[1] for r in results))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        second_toks, second_all_logits = zip(*results)
         self.assertEqual(toks, second_toks)
         self.assertTrue(
             all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
@@ -278,14 +282,16 @@ class TestPromptCache(unittest.TestCase):
     def test_cache_to_quantized(self):
         model, tokenizer = load(HF_MODEL_PATH)
         prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
-        results = zip(range(4), generate_step(prompt, model))
-        toks, all_logits = zip(*(r[1] for r in results))
+        results = zip(range(4), generate_step(prompt[None], model))
+        results = [(t.item(), l.squeeze(0)) for _, (t, l) in results]
+        toks, all_logits = zip(*results)
 
         prompt_cache = make_prompt_cache(model)
         i = 0
         for _, (tok, logits) in zip(
-            range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
+            range(2), generate_step(prompt[None], model, prompt_cache=prompt_cache)
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i]))
             i += 1
@@ -294,8 +300,9 @@ class TestPromptCache(unittest.TestCase):
 
         for _, (tok, logits) in zip(
             range(1),
-            generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
+            generate_step(mx.array([[toks[i]]]), model, prompt_cache=prompt_cache),
         ):
+            tok, logits = tok.item(), logits.squeeze(0)
             i += 1
             self.assertEqual(tok, toks[i])
             self.assertTrue(mx.allclose(logits, all_logits[i], rtol=2e-2))
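For downstream callers, the net effect of this diff is that generate_step now takes a 2-D, batch-first prompt and yields mx.array tokens plus logprobs with a leading batch axis. A minimal usage sketch under those assumptions (the model path is a placeholder, not part of this change):

# A minimal sketch of the new calling convention, assuming mlx_lm's
# load() and the generate_step() iterator as modified above.
import mlx.core as mx

from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# Prompts are now batched: encode to a 1-D array, then add a leading
# batch axis with [None] so the prompt has shape (1, seq_len).
prompt = mx.array(tokenizer.encode("this is a prompt"))[None]

for _, (token, logprobs) in zip(range(4), generate_step(prompt, model)):
    # Tokens come back as mx.array scalars and logprobs carry a leading
    # batch axis, so unwrap them the same way stream_generate now does.
    token, logprobs = token.item(), logprobs.squeeze(0)
    print(token, logprobs.shape)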