Add "from_draft" to GenerationResponse (#1272)

* Add from_draft field in GenerationResponse

* Cleanup

* Re-work for minimal changes, add test

* Fix comment
Matt Clayton 2025-02-11 18:41:02 -05:00 committed by GitHub
parent bded1a8fcd
commit 3d677f0870
2 changed files with 65 additions and 10 deletions
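
For orientation, here is a minimal usage sketch of the new field; it is not part of the diff, and the model path, prompt, and max_tokens are illustrative (they mirror the test added below). With a draft_model passed to stream_generate, each GenerationResponse now reports whether its token was accepted from the draft model, so a caller can measure the draft acceptance rate:

    from mlx_lm.utils import load, stream_generate

    # Illustrative: use the same small 4-bit model as both target and draft,
    # just as the new test does.
    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
    draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

    from_draft_count = 0
    total = 0
    for response in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt="hello",
        max_tokens=32,
        draft_model=draft_model,
        num_draft_tokens=2,
    ):
        total += 1
        if response.from_draft:  # new field added by this commit
            from_draft_count += 1
    print(f"{from_draft_count}/{total} tokens were accepted from the draft model")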

View File

@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
     prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@@ -338,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
 
@@ -365,7 +378,8 @@ def speculative_generate_step(
           when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+          and a bool indicating if the token was generated by the draft model
     """
 
     y = prompt
@@ -450,12 +464,12 @@
                 break
             n += 1
             ntoks += 1
-            yield tn, lpn
+            yield tn, lpn, True
             if ntoks == max_tokens:
                 break
         if ntoks < max_tokens:
             ntoks += 1
-            yield tokens[n], logprobs[n]
+            yield tokens[n], logprobs[n], False
             if ntoks == max_tokens:
                 break
 
@@ -463,7 +477,7 @@
         y = mx.array([tokens[n]], mx.uint32)
         draft_y = y
 
-        # If we accpeted all the draft tokens, include the last
+        # If we accepted all the draft tokens, include the last
         # draft token in the next draft step since it hasn't been
         # processed yet by the draft model
         if n == num_draft:
@@ -518,6 +532,10 @@ def stream_generate(
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
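
The non-speculative branch above wraps generate_step so that both branches yield (token, logprobs, from_draft) triples and the consumer loop can unpack them uniformly. A tiny standalone sketch of the same adapter pattern (names are illustrative, not library code):

    def with_draft_flag(pairs):
        # Pad a (token, logprobs) stream with a constant from_draft=False.
        return ((token, logprobs, False) for token, logprobs in pairs)

    for token, logprobs, from_draft in with_draft_flag([(1, None), (2, None)]):
        assert from_draft is False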
@@ -526,7 +544,7 @@
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@
                     text=detokenizer.last_segment,
                     token=token,
                     logprobs=logprobs,
+                    from_draft=from_draft,
                     prompt_tokens=prompt.size,
                     prompt_tps=prompt_tps,
                     generation_tokens=n + 1,
@@ -553,6 +572,7 @@
             text=detokenizer.last_segment,
             token=token,
             logprobs=logprobs,
+            from_draft=from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
             generation_tokens=n + 1,

View File

@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.
 
 import unittest
+from typing import List
 
 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)
 
 
 class TestGenerate(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
 
     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
 
+    def test_stream_generate_speculative(self):
+        # Use the same model as the draft model; this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a deterministic sampler
+        sampler = make_sampler(temp=0.0)
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+
+        # Since num_draft_tokens is 2 and the draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and the last two should be drafts.
+        self.assertEqual(drafted, [True, True, False, True, True])
+
 
 if __name__ == "__main__":
     unittest.main()