mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Add from_draft field in GenerationResponse
This commit is contained in:
parent
5865899c81
commit
93591970cf
@ -13,7 +13,7 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
|
from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -65,6 +65,7 @@ class GenerationResponse:
|
|||||||
Args:
|
Args:
|
||||||
text (str): The next segment of decoded text. This can be an empty string.
|
text (str): The next segment of decoded text. This can be an empty string.
|
||||||
token (int): The next token.
|
token (int): The next token.
|
||||||
|
from_draft (bool): Whether the token was generated by a draft model.
|
||||||
logprobs (mx.array): A vector of log probabilities.
|
logprobs (mx.array): A vector of log probabilities.
|
||||||
prompt_tokens (int): The number of tokens in the prompt.
|
prompt_tokens (int): The number of tokens in the prompt.
|
||||||
prompt_tps (float): The prompt processing tokens-per-second.
|
prompt_tps (float): The prompt processing tokens-per-second.
|
||||||
@ -76,6 +77,7 @@ class GenerationResponse:
|
|||||||
|
|
||||||
text: str
|
text: str
|
||||||
token: int
|
token: int
|
||||||
|
from_draft: bool = False
|
||||||
logprobs: mx.array
|
logprobs: mx.array
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
prompt_tps: float
|
prompt_tps: float
|
||||||
@ -205,6 +207,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
|||||||
group_size=kv_group_size, bits=kv_bits
|
group_size=kv_group_size, bits=kv_bits
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class TokenMetadata(NamedTuple):
|
||||||
|
from_draft: bool = False
|
||||||
|
|
||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
@ -220,7 +224,7 @@ def generate_step(
|
|||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
quantized_kv_start: int = 0,
|
quantized_kv_start: int = 0,
|
||||||
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
@ -248,7 +252,7 @@ def generate_step(
|
|||||||
prompt tokens processed so far and the total number of prompt tokens.
|
prompt tokens processed so far and the total number of prompt tokens.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
@ -323,7 +327,6 @@ def generate_step(
|
|||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
n += 1
|
n += 1
|
||||||
|
|
||||||
|
|
||||||
def speculative_generate_step(
|
def speculative_generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -338,7 +341,7 @@ def speculative_generate_step(
|
|||||||
kv_bits: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
quantized_kv_start: int = 0,
|
quantized_kv_start: int = 0,
|
||||||
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
@ -365,7 +368,7 @@ def speculative_generate_step(
|
|||||||
when ``kv_bits`` is non-None. Default: ``0``.
|
when ``kv_bits`` is non-None. Default: ``0``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
@ -450,12 +453,12 @@ def speculative_generate_step(
|
|||||||
break
|
break
|
||||||
n += 1
|
n += 1
|
||||||
ntoks += 1
|
ntoks += 1
|
||||||
yield tn, lpn
|
yield tn, lpn, TokenMetadata(from_draft=True)
|
||||||
if ntoks == max_tokens:
|
if ntoks == max_tokens:
|
||||||
break
|
break
|
||||||
if ntoks < max_tokens:
|
if ntoks < max_tokens:
|
||||||
ntoks += 1
|
ntoks += 1
|
||||||
yield tokens[n], logprobs[n]
|
yield tokens[n], logprobs[n], TokenMetadata(from_draft=False)
|
||||||
|
|
||||||
if ntoks == max_tokens:
|
if ntoks == max_tokens:
|
||||||
break
|
break
|
||||||
@ -526,7 +529,7 @@ def stream_generate(
|
|||||||
with wired_limit(model, [generation_stream]):
|
with wired_limit(model, [generation_stream]):
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
for n, (token, logprobs) in enumerate(token_generator):
|
for n, (token, logprobs, token_metadata) in enumerate(token_generator):
|
||||||
if n == 0:
|
if n == 0:
|
||||||
prompt_time = time.perf_counter() - tic
|
prompt_time = time.perf_counter() - tic
|
||||||
prompt_tps = prompt.size / prompt_time
|
prompt_tps = prompt.size / prompt_time
|
||||||
@ -539,6 +542,7 @@ def stream_generate(
|
|||||||
yield GenerationResponse(
|
yield GenerationResponse(
|
||||||
text=detokenizer.last_segment,
|
text=detokenizer.last_segment,
|
||||||
token=token,
|
token=token,
|
||||||
|
from_draft=token_metadata.from_draft,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
prompt_tokens=prompt.size,
|
prompt_tokens=prompt.size,
|
||||||
prompt_tps=prompt_tps,
|
prompt_tps=prompt_tps,
|
||||||
@ -552,6 +556,7 @@ def stream_generate(
|
|||||||
yield GenerationResponse(
|
yield GenerationResponse(
|
||||||
text=detokenizer.last_segment,
|
text=detokenizer.last_segment,
|
||||||
token=token,
|
token=token,
|
||||||
|
from_draft=token_metadata.from_draft,
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
prompt_tokens=prompt.size,
|
prompt_tokens=prompt.size,
|
||||||
prompt_tps=prompt_tps,
|
prompt_tps=prompt_tps,
|
||||||
|
Loading…
Reference in New Issue
Block a user