diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 06784f10..b9fc202d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -310,10 +310,14 @@ def generate_step( y, logprobs = _step(y) mx.async_eval(y, logprobs) + n = 0 while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs + if n % 256 == 0: + mx.metal.clear_cache() + n += 1 y, logprobs = next_y, next_logprobs