diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 1e07546e..4f872982 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -246,10 +246,10 @@ def generate_step( y, logprobs = _step(y) - mx.async_eval(y) + mx.async_eval(y, logprobs) while True: next_y, next_logprobs = _step(y) - mx.async_eval(next_y) + mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs y, logprobs = next_y, next_logprobs @@ -348,7 +348,9 @@ def generate( if formatter: # We have to finalize so that the prob corresponds to the last segment detokenizer.finalize() - formatter(detokenizer.last_segment, mx.exp(logprobs[token]).item()) + with mx.stream(mx.cpu): + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) else: print(detokenizer.last_segment, end="", flush=True)