diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b2e89a13..005c877a 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -13,7 +13,7 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +65,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by a draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -76,6 +77,7 @@
 
     text: str
     token: int
+    from_draft: bool
     logprobs: mx.array
     prompt_tokens: int
     prompt_tps: float
@@ -205,6 +207,8 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
                     group_size=kv_group_size, bits=kv_bits
                 )
 
+class TokenMetadata(NamedTuple):
+    from_draft: bool = False
 
 def generate_step(
     prompt: mx.array,
@@ -220,7 +224,7 @@
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
     prompt_progress_callback: Optional[Callable[int, int]] = None,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
 
@@ -248,7 +252,7 @@
             prompt tokens processed so far and the total number of prompt tokens.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
     """
 
     y = prompt
@@ -323,7 +327,6 @@
         y, logprobs = next_y, next_logprobs
         n += 1
 
-
 def speculative_generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -338,7 +341,7 @@
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
 
@@ -365,7 +368,7 @@
             when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log probabilities, and token metadata.
""" y = prompt @@ -450,12 +453,12 @@ def speculative_generate_step( break n += 1 ntoks += 1 - yield tn, lpn + yield tn, lpn, TokenMetadata(from_draft=True) if ntoks == max_tokens: break if ntoks < max_tokens: ntoks += 1 - yield tokens[n], logprobs[n] + yield tokens[n], logprobs[n], TokenMetadata(from_draft=False) if ntoks == max_tokens: break @@ -526,7 +529,7 @@ def stream_generate( with wired_limit(model, [generation_stream]): detokenizer.reset() tic = time.perf_counter() - for n, (token, logprobs) in enumerate(token_generator): + for n, (token, logprobs, token_metadata) in enumerate(token_generator): if n == 0: prompt_time = time.perf_counter() - tic prompt_tps = prompt.size / prompt_time @@ -539,6 +542,7 @@ def stream_generate( yield GenerationResponse( text=detokenizer.last_segment, token=token, + from_draft=token_metadata.from_draft, logprobs=logprobs, prompt_tokens=prompt.size, prompt_tps=prompt_tps, @@ -552,6 +556,7 @@ def stream_generate( yield GenerationResponse( text=detokenizer.last_segment, token=token, + from_draft=token_metadata.from_draft, logprobs=logprobs, prompt_tokens=prompt.size, prompt_tps=prompt_tps,