Add from_draft field in GenerationResponse

Matt Clayton 2025-02-10 12:00:11 -05:00
parent 5865899c81
commit 93591970cf


@ -13,7 +13,7 @@ import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
import mlx.core as mx
import mlx.nn as nn
@ -65,6 +65,7 @@ class GenerationResponse:
Args:
text (str): The next segment of decoded text. This can be an empty string.
token (int): The next token.
from_draft (bool): Whether the token was generated by a draft model.
logprobs (mx.array): A vector of log probabilities.
prompt_tokens (int): The number of tokens in the prompt.
prompt_tps (float): The prompt processing tokens-per-second.
@ -76,6 +77,7 @@ class GenerationResponse:
text: str
token: int
from_draft: bool
logprobs: mx.array
prompt_tokens: int
prompt_tps: float
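
A minimal consumption sketch for the new field (illustrative only; `responses` stands for any iterable of GenerationResponse objects, such as the output of stream_generate further down):

draft_accepted = 0
total = 0
for r in responses:
    total += 1
    if r.from_draft:
        draft_accepted += 1
print(f"{draft_accepted}/{total} tokens were accepted draft proposals")
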
@ -205,6 +207,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
group_size=kv_group_size, bits=kv_bits
)
class TokenMetadata(NamedTuple):
from_draft: bool = False
def generate_step(
prompt: mx.array,
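
A small sketch of the TokenMetadata behavior (illustrative only): the default keeps ordinary yields cheap to tag, while the speculative path marks accepted draft tokens explicitly.

meta_plain = TokenMetadata()                 # from_draft defaults to False
meta_draft = TokenMetadata(from_draft=True)  # tags a token proposed by the draft model
assert meta_plain.from_draft is False
assert meta_draft.from_draft is True
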
@ -220,7 +224,7 @@ def generate_step(
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@ -248,7 +252,7 @@ def generate_step(
prompt tokens processed so far and the total number of prompt tokens.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
"""
y = prompt
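
A hedged usage sketch against the new generate_step signature (it assumes the body's yields now include the TokenMetadata element, as the annotation indicates; `model` and `prompt_ids` are placeholders):

for token, logprobs, meta in generate_step(prompt_ids, model, max_tokens=32):
    # With no draft model involved, meta.from_draft is expected to stay False.
    print(token, meta.from_draft)
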
@ -323,7 +327,6 @@ def generate_step(
y, logprobs = next_y, next_logprobs
n += 1
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
@ -338,7 +341,7 @@ def speculative_generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@ -365,7 +368,7 @@ def speculative_generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
"""
y = prompt
@ -450,12 +453,12 @@ def speculative_generate_step(
break
n += 1
ntoks += 1
yield tn, lpn
yield tn, lpn, TokenMetadata(from_draft=True)
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n]
yield tokens[n], logprobs[n], TokenMetadata(from_draft=False)
if ntoks == max_tokens:
break
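
A consumption sketch for the speculative path (hypothetical setup; `model`, `draft_model`, and `prompt_ids` are placeholders), counting how many emitted tokens came from accepted draft proposals:

accepted = 0
generated = 0
for token, logprobs, meta in speculative_generate_step(
    prompt_ids, model, draft_model, max_tokens=64
):
    generated += 1
    if meta.from_draft:
        accepted += 1
print(f"{accepted}/{generated} tokens accepted from the draft model")
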
@ -526,7 +529,7 @@ def stream_generate(
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs) in enumerate(token_generator):
for n, (token, logprobs, token_metadata) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@ -539,6 +542,7 @@ def stream_generate(
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
from_draft=token_metadata.from_draft,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
@ -552,6 +556,7 @@ def stream_generate(
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
from_draft=token_metadata.from_draft,
logprobs=logprobs,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
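
End to end, a hedged sketch of reading from_draft off stream_generate responses (it assumes mlx_lm's load/stream_generate entry points and that a draft_model argument is forwarded to speculative_generate_step; model paths are placeholders):

from mlx_lm import load, stream_generate

model, tokenizer = load("path/to/model")              # placeholder path
draft_model, _ = load("path/to/smaller-draft-model")  # placeholder path

from_draft_count = 0
for response in stream_generate(
    model,
    tokenizer,
    "Write a haiku about speculative decoding.",
    draft_model=draft_model,
    max_tokens=128,
):
    print(response.text, end="", flush=True)
    if response.from_draft:
        from_draft_count += 1
print(f"\n{from_draft_count} tokens came from the draft model")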