mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

Add finish_reason in GenerationResponse (#1153)

parent 77b42b7c8b
commit 06af3c9b0e
@@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten, tree_map, tree_reduce
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -59,6 +59,7 @@ class GenerationResponse:
         generation_tokens (int): The number of generated tokens.
         generation_tps (float): The tokens-per-second for generation.
         peak_memory (float): The peak memory used so far in GB.
+        finish_reason (str): The reason the response is being sent: "length", "stop", or `None`.
     """
 
     text: str
@@ -69,6 +70,7 @@ class GenerationResponse:
     generation_tokens: int
     generation_tps: float
     peak_memory: float
+    finish_reason: Optional[str] = None
 
 
 @contextlib.contextmanager
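Taken together, the two hunks above give the response dataclass its new shape. The following is a minimal sketch of that shape, reconstructed only from the context lines visible in this diff; fields the hunks do not show are omitted, so the real class is larger:

from dataclasses import dataclass
from typing import Optional

@dataclass
class GenerationResponse:
    # Only the fields visible in the hunks above; the real class carries more.
    text: str
    generation_tokens: int
    generation_tps: float
    peak_memory: float
    # None while streaming; "stop" or "length" on the final response.
    finish_reason: Optional[str] = None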
@@ -375,6 +377,7 @@ def stream_generate(
                 generation_tokens=n + 1,
                 generation_tps=(n + 1) / (time.perf_counter() - tic),
                 peak_memory=mx.metal.get_peak_memory() / 1e9,
+                finish_reason=None,
             )
 
         detokenizer.finalize()
@@ -387,6 +390,7 @@ def stream_generate(
             generation_tokens=n + 1,
             generation_tps=(n + 1) / (time.perf_counter() - tic),
             peak_memory=mx.metal.get_peak_memory() / 1e9,
+            finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
         )
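With these hunks, every chunk yielded mid-stream carries finish_reason=None, and the final chunk reports why generation ended: "stop" if an EOS token was generated, "length" if the token budget ran out. A minimal consumer sketch, assuming mlx-lm's public load()/stream_generate() API; the model repo name is a placeholder:

from mlx_lm import load, stream_generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # placeholder repo

last = None
for response in stream_generate(model, tokenizer, "Write a haiku.", max_tokens=64):
    print(response.text, end="", flush=True)
    last = response  # the final chunk carries the definitive finish_reason

# Mid-stream chunks have finish_reason=None; the last one is "stop"
# (an EOS token was generated) or "length" (max_tokens was exhausted).
if last is not None:
    print(f"\nfinish_reason: {last.finish_reason}")

The "stop"/"length" values match the finish_reason convention used by OpenAI-style APIs, so a server can forward the field unchanged.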