Add finish_reason in GenerationResponse (#1153)

This commit is contained in:
madroid 2024-12-13 02:37:40 +08:00 committed by GitHub
parent 77b42b7c8b
commit 06af3c9b0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer

 # Local imports
@@ -59,6 +59,7 @@ class GenerationResponse:
         generation_tokens (int): The number of generated tokens.
         generation_tps (float): The tokens-per-second for generation.
         peak_memory (float): The peak memory used so far in GB.
+        finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
     """

     text: str
@@ -69,6 +70,7 @@ class GenerationResponse:
     generation_tokens: int
     generation_tps: float
     peak_memory: float
+    finish_reason: Optional[str] = None


 @contextlib.contextmanager
@@ -375,6 +377,7 @@ def stream_generate(
                 generation_tokens=n + 1,
                 generation_tps=(n + 1) / (time.perf_counter() - tic),
                 peak_memory=mx.metal.get_peak_memory() / 1e9,
+                finish_reason=None,
             )

         detokenizer.finalize()
@@ -387,6 +390,7 @@ def stream_generate(
             generation_tokens=n + 1,
             generation_tps=(n + 1) / (time.perf_counter() - tic),
             peak_memory=mx.metal.get_peak_memory() / 1e9,
+            finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
         )