From c799133998a943affdc395c5cb159ead2576a225 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 14 Oct 2024 10:25:24 -0700
Subject: [PATCH] Make llm async eval less brittle (#1040)

* Make llm async eval less brittle

* nit
---
 llms/mlx_lm/utils.py | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 1e07546e..4f872982 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -246,10 +246,10 @@ def generate_step(

     y, logprobs = _step(y)

-    mx.async_eval(y)
+    mx.async_eval(y, logprobs)
     while True:
         next_y, next_logprobs = _step(y)
-        mx.async_eval(next_y)
+        mx.async_eval(next_y, next_logprobs)
         yield y.item(), logprobs
         y, logprobs = next_y, next_logprobs

@@ -348,7 +348,9 @@ def generate(
         if formatter:
             # We have to finalize so that the prob corresponds to the last segment
             detokenizer.finalize()
-            formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item())
+            with mx.stream(mx.cpu):
+                prob = mx.exp(logprobs[token]).item()
+            formatter(detokenizer.last_segment, prob)
         else:
             print(detokenizer.last_segment, end="", flush=True)
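
Note (not part of the patch): a minimal sketch of the pipelined-decode pattern this diff relies on. `mx.async_eval`, `mx.stream`, and `mx.cpu` are real MLX APIs; the `_step` function and shapes below are hypothetical stand-ins for the model step in `generate_step`.

    import mlx.core as mx

    def _step(y):
        # Hypothetical stand-in for one decode step: returns the next token
        # and its log-probabilities as lazy (unevaluated) MLX arrays.
        logits = mx.random.normal((1, 32))  # placeholder for a model forward pass
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        return mx.argmax(logprobs, axis=-1), logprobs.squeeze(0)

    y, logprobs = _step(mx.array([0]))
    # Schedule BOTH arrays. If only `y` were scheduled, reading `logprobs` later
    # (e.g. logprobs[token].item()) could force a second, unpipelined evaluation,
    # which appears to be the brittleness the patch addresses.
    mx.async_eval(y, logprobs)
    for _ in range(4):
        next_y, next_logprobs = _step(y)
        mx.async_eval(next_y, next_logprobs)  # overlap the next step with consuming this one
        token = y.item()  # already materialized by the earlier async_eval
        with mx.stream(mx.cpu):  # do the tiny scalar read on the CPU stream, as in the patch
            prob = mx.exp(logprobs[token]).item()
        y, logprobs = next_y, next_logprobs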