Add finish_reason in GenerationResponse (#1153)

This commit is contained in:
madroid 2024-12-13 02:37:40 +08:00 committed by GitHub
parent 77b42b7c8b
commit 06af3c9b0e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten, tree_map, tree_reduce
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
# Local imports
@ -59,6 +59,7 @@ class GenerationResponse:
generation_tokens (int): The number of generated tokens.
generation_tps (float): The tokens-per-second for generation.
peak_memory (float): The peak memory used so far in GB.
finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
"""
text: str
@ -69,6 +70,7 @@ class GenerationResponse:
generation_tokens: int
generation_tps: float
peak_memory: float
finish_reason: Optional[str] = None
@contextlib.contextmanager
@ -375,6 +377,7 @@ def stream_generate(
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
finish_reason=None,
)
detokenizer.finalize()
@ -387,6 +390,7 @@ def stream_generate(
generation_tokens=n + 1,
generation_tps=(n + 1) / (time.perf_counter() - tic),
peak_memory=mx.metal.get_peak_memory() / 1e9,
finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
)