Add from_draft field in GenerationResponse

This commit is contained in:
Matt Clayton 2025-02-10 12:00:11 -05:00
parent 5865899c81
commit 93591970cf

View File

@ -13,7 +13,7 @@ import time
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from textwrap import dedent from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
@ -65,6 +65,7 @@ class GenerationResponse:
Args: Args:
text (str): The next segment of decoded text. This can be an empty string. text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token. token (int): The next token.
from_draft (bool): Whether the token was generated by a draft model.
logprobs (mx.array): A vector of log probabilities. logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt. prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second. prompt_tps (float): The prompt processing tokens-per-second.
@ -76,6 +77,7 @@ class GenerationResponse:
text: str text: str
token: int token: int
from_draft: bool = False
logprobs: mx.array logprobs: mx.array
prompt_tokens: int prompt_tokens: int
prompt_tps: float prompt_tps: float
@ -205,6 +207,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
group_size=kv_group_size, bits=kv_bits group_size=kv_group_size, bits=kv_bits
) )
class TokenMetadata(NamedTuple):
from_draft: bool = False
def generate_step( def generate_step(
prompt: mx.array, prompt: mx.array,
@ -220,7 +224,7 @@ def generate_step(
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None, prompt_progress_callback: Optional[Callable[int, int]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -248,7 +252,7 @@ def generate_step(
prompt tokens processed so far and the total number of prompt tokens. prompt tokens processed so far and the total number of prompt tokens.
Yields: Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities. Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
""" """
y = prompt y = prompt
@ -323,7 +327,6 @@ def generate_step(
y, logprobs = next_y, next_logprobs y, logprobs = next_y, next_logprobs
n += 1 n += 1
def speculative_generate_step( def speculative_generate_step(
prompt: mx.array, prompt: mx.array,
model: nn.Module, model: nn.Module,
@ -338,7 +341,7 @@ def speculative_generate_step(
kv_bits: Optional[int] = None, kv_bits: Optional[int] = None,
kv_group_size: int = 64, kv_group_size: int = 64,
quantized_kv_start: int = 0, quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]: ) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
""" """
A generator producing token ids based on the given prompt from the model. A generator producing token ids based on the given prompt from the model.
@ -365,7 +368,7 @@ def speculative_generate_step(
when ``kv_bits`` is non-None. Default: ``0``. when ``kv_bits`` is non-None. Default: ``0``.
Yields: Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities. Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
""" """
y = prompt y = prompt
@ -450,12 +453,12 @@ def speculative_generate_step(
break break
n += 1 n += 1
ntoks += 1 ntoks += 1
yield tn, lpn yield tn, lpn, TokenMetadata(from_draft=True)
if ntoks == max_tokens: if ntoks == max_tokens:
break break
if ntoks < max_tokens: if ntoks < max_tokens:
ntoks += 1 ntoks += 1
yield tokens[n], logprobs[n] yield tokens[n], logprobs[n], TokenMetadata(from_draft=False)
if ntoks == max_tokens: if ntoks == max_tokens:
break break
@ -526,7 +529,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]): with wired_limit(model, [generation_stream]):
detokenizer.reset() detokenizer.reset()
tic = time.perf_counter() tic = time.perf_counter()
for n, (token, logprobs) in enumerate(token_generator): for n, (token, logprobs, token_metadata) in enumerate(token_generator):
if n == 0: if n == 0:
prompt_time = time.perf_counter() - tic prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time prompt_tps = prompt.size / prompt_time
@ -539,6 +542,7 @@ def stream_generate(
yield GenerationResponse( yield GenerationResponse(
text=detokenizer.last_segment, text=detokenizer.last_segment,
token=token, token=token,
from_draft=token_metadata.from_draft,
logprobs=logprobs, logprobs=logprobs,
prompt_tokens=prompt.size, prompt_tokens=prompt.size,
prompt_tps=prompt_tps, prompt_tps=prompt_tps,
@ -552,6 +556,7 @@ def stream_generate(
yield GenerationResponse( yield GenerationResponse(
text=detokenizer.last_segment, text=detokenizer.last_segment,
token=token, token=token,
from_draft=token_metadata.from_draft,
logprobs=logprobs, logprobs=logprobs,
prompt_tokens=prompt.size, prompt_tokens=prompt.size,
prompt_tps=prompt_tps, prompt_tps=prompt_tps,