diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 8649fbe3..cfbcf29e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -239,8 +239,8 @@ def generate_step( return y, logprobs.squeeze(0) while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=cache) - mx.eval([c.state for c in cache]) + model(y[:prefill_step_size][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) y = y[prefill_step_size:] y, logprobs = _step(y)