From cfc29c29f45372c78876335a44b0c99ab6565ae0 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 25 Nov 2024 09:47:00 -0800
Subject: [PATCH] Put prompt processing in same stream (#1122)

* put prompt processing in same stream

* patch
---
 llms/mlx_lm/_version.py |  2 +-
 llms/mlx_lm/utils.py    | 13 +++++++------
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index 5168eee4..343e0016 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.20.0"
+__version__ = "0.20.1"
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 5abd396d..0e2f7af7 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -274,13 +274,14 @@ def generate_step(
         y = sampler(logprobs)
         return y, logprobs.squeeze(0)
 
-    while y.size > prefill_step_size:
-        model(y[:prefill_step_size][None], cache=prompt_cache)
-        mx.eval([c.state for c in prompt_cache])
-        y = y[prefill_step_size:]
-        mx.metal.clear_cache()
+    with mx.stream(generation_stream):
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
 
-    y, logprobs = _step(y)
+        y, logprobs = _step(y)
 
     mx.async_eval(y, logprobs)
     n = 0
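
Note (not part of the patch): the change relies on mx.stream acting as a context manager that routes operations onto a given stream, so the chunked prompt prefill and the first _step call are ordered on generation_stream rather than the default stream. A minimal standalone sketch of that mechanism, assuming the MLX Python API (mlx.core) and using an illustrative stream named s:

# Minimal sketch of stream scoping in MLX (illustrative; not from the patch).
import mlx.core as mx

# Create a dedicated stream on the default device, analogous to
# generation_stream in mlx_lm/utils.py.
s = mx.new_stream(mx.default_device())

with mx.stream(s):
    # Operations built inside this block are scheduled on stream `s`
    # instead of the default stream, so they stay ordered with respect
    # to each other.
    a = mx.ones((1024, 1024))
    b = a @ a

mx.eval(b)  # force evaluation; the work above ran on stream `s`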