diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index cfbcf29e..1e07546e 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -242,6 +242,7 @@ def generate_step( model(y[:prefill_step_size][None], cache=prompt_cache) mx.eval([c.state for c in prompt_cache]) y = y[prefill_step_size:] + mx.metal.clear_cache() y, logprobs = _step(y)