From 3586c876aa4a014724c7d4bab160ae5a3b10093a Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 25 Nov 2024 08:06:12 -0800
Subject: [PATCH] put prompt processing in same stream

---
 llms/mlx_lm/utils.py | 13 +++++++------
 1 file changed, 7 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 496ae4fc..6e9c7ded 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -274,13 +274,14 @@ def generate_step(
         y = sampler(logprobs)
         return y, logprobs.squeeze(0)
 
-    while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=prompt_cache)
-        mx.eval([c.state for c in prompt_cache])
-        y = y[prefill_step_size:]
-        mx.metal.clear_cache()
+    with mx.stream(generation_stream):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
 
-    y, logprobs = _step(y)
+        y, logprobs = _step(y)
 
     mx.async_eval(y, logprobs)
     n = 0
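
The change wraps chunked prompt prefill and the first sampled token in a single mx.stream context, so both run on the generation stream before mx.async_eval is issued. Below is a minimal sketch of the same pattern, assuming a dedicated stream created with mx.new_stream; the helper name prefill_then_step and the step_fn callable are illustrative, not the actual mlx_lm internals.

import mlx.core as mx

# Dedicated stream on the default device; mlx_lm creates one like this at
# module scope under the name generation_stream.
generation_stream = mx.new_stream(mx.default_device())

def prefill_then_step(step_fn, y, prefill_step_size=512):
    """Hypothetical helper: step_fn maps a chunk of tokens to an output array.

    Chunked prompt processing and the first decode step both run inside the
    same stream context, so their kernels are ordered on one queue.
    """
    with mx.stream(generation_stream):
        while y.size > prefill_step_size:
            out = step_fn(y[:prefill_step_size])
            mx.eval(out)                  # materialize this chunk before moving on
            y = y[prefill_step_size:]
        out = step_fn(y)                  # first step on the remaining tokens
    mx.async_eval(out)                    # launch evaluation asynchronously
    return out

Keeping prefill in the same stream context as the first decode step means the prompt-processing work and the subsequent async evaluation are enqueued on one stream rather than split between the default stream and the generation stream.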