diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 853e1c95..240e5dd9 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -300,10 +300,9 @@ def stream_generate( range(max_tokens), generate_step(prompt_tokens, model, **kwargs), ): - if token == tokenizer.eos_token_id: - break detokenizer.add_token(token) - + if n == (max_tokens - 1) or token == tokenizer.eos_token_id: + break # Yield the last segment if streaming yield detokenizer.last_segment, token, logits