mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Add finish_reason in GenerationResponse (#1153)
This commit is contained in:
parent
77b42b7c8b
commit
06af3c9b0e
@ -16,7 +16,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from mlx.utils import tree_flatten, tree_map, tree_reduce
|
from mlx.utils import tree_flatten, tree_reduce
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
@ -59,6 +59,7 @@ class GenerationResponse:
|
|||||||
generation_tokens (int): The number of generated tokens.
|
generation_tokens (int): The number of generated tokens.
|
||||||
generation_tps (float): The tokens-per-second for generation.
|
generation_tps (float): The tokens-per-second for generation.
|
||||||
peak_memory (float): The peak memory used so far in GB.
|
peak_memory (float): The peak memory used so far in GB.
|
||||||
|
finish_reason (str): The reason the response is being sent: "length", "stop" or `None`
|
||||||
"""
|
"""
|
||||||
|
|
||||||
text: str
|
text: str
|
||||||
@ -69,6 +70,7 @@ class GenerationResponse:
|
|||||||
generation_tokens: int
|
generation_tokens: int
|
||||||
generation_tps: float
|
generation_tps: float
|
||||||
peak_memory: float
|
peak_memory: float
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
@contextlib.contextmanager
|
||||||
@ -375,6 +377,7 @@ def stream_generate(
|
|||||||
generation_tokens=n + 1,
|
generation_tokens=n + 1,
|
||||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||||
peak_memory=mx.metal.get_peak_memory() / 1e9,
|
peak_memory=mx.metal.get_peak_memory() / 1e9,
|
||||||
|
finish_reason=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
detokenizer.finalize()
|
detokenizer.finalize()
|
||||||
@ -387,6 +390,7 @@ def stream_generate(
|
|||||||
generation_tokens=n + 1,
|
generation_tokens=n + 1,
|
||||||
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
generation_tps=(n + 1) / (time.perf_counter() - tic),
|
||||||
peak_memory=mx.metal.get_peak_memory() / 1e9,
|
peak_memory=mx.metal.get_peak_memory() / 1e9,
|
||||||
|
finish_reason="stop" if token in tokenizer.eos_token_ids else "length",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user