diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 496ae4fc..6e9c7ded 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -274,13 +274,14 @@ def generate_step( y = sampler(logprobs) return y, logprobs.squeeze(0) - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - y = y[prefill_step_size:] - mx.metal.clear_cache() + with mx.stream(generation_stream): + while y.size > prefill_step_size: + model(y[:prefill_step_size][None], cache=prompt_cache) + mx.eval([c.state for c in prompt_cache]) + y = y[prefill_step_size:] + mx.metal.clear_cache() - y, logprobs = _step(y) + y, logprobs = _step(y) mx.async_eval(y, logprobs) n = 0