From 06af3c9b0eac1aea927dcbbda66cacd5aab76f4a Mon Sep 17 00:00:00 2001
From: madroid
Date: Fri, 13 Dec 2024 02:37:40 +0800
Subject: [PATCH] Add finish_reason in GenerationResponse (#1153)

---
 llms/mlx_lm/utils.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index d81bb66a..493c1c42 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -59,6 +59,7 @@ class GenerationResponse:
         generation_tokens (int): The number of generated tokens.
         generation_tps (float): The tokens-per-second for generation.
         peak_memory (float): The peak memory used so far in GB.
+        finish_reason (str): The reason the response is being sent: "length", "stop", or `None`.
     """
 
     text: str
@@ -69,6 +70,7 @@
     generation_tokens: int
     generation_tps: float
     peak_memory: float
+    finish_reason: Optional[str] = None
 
 
 @contextlib.contextmanager
@@ -375,6 +377,7 @@ def stream_generate(
                 generation_tokens=n + 1,
                 generation_tps=(n + 1) / (time.perf_counter() - tic),
                 peak_memory=mx.metal.get_peak_memory() / 1e9,
+                finish_reason=None,
             )
 
         detokenizer.finalize()
@@ -387,6 +390,7 @@
             generation_tokens=n + 1,
             generation_tps=(n + 1) / (time.perf_counter() - tic),
             peak_memory=mx.metal.get_peak_memory() / 1e9,
+            finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
         )
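
With this change, every chunk yielded by `stream_generate` carries
`finish_reason=None` while generation is in progress; the final chunk,
emitted after `detokenizer.finalize()`, reports "stop" when an EOS token
ended generation or "length" when the token budget ran out. A minimal
consumer sketch (assuming `mlx_lm` is installed so `load` and
`stream_generate` are importable; the model id below is a placeholder):

    from mlx_lm import load, stream_generate

    # Placeholder model id; any local path or mlx-community model works.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    last = None
    for response in stream_generate(model, tokenizer, "Write a haiku.", max_tokens=64):
        print(response.text, end="", flush=True)
        last = response

    # Intermediate chunks have finish_reason=None; the final chunk reports
    # "stop" (an EOS token was generated) or "length" (max_tokens was reached).
    print(f"\nfinish_reason: {last.finish_reason}")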