diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 5ee043a3..4d815810 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -345,9 +345,7 @@ def stream_generate( with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() - for n, (token, logprobs) in enumerate( - generate_step(prompt, model, **kwargs), - ): + for n, (token, logprobs) in enumerate(generate_step(prompt, model, **kwargs)): if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time