Re-work for minimal changes, add test

Matt Clayton 2025-02-10 12:30:12 -05:00
parent fff5daeb85
commit 4df23e961c
2 changed files with 66 additions and 18 deletions

View File

@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)

 import mlx.core as mx
 import mlx.nn as nn
@@ -77,8 +88,8 @@ class GenerationResponse:

     text: str
     token: int
-    from_draft: bool = False
     logprobs: mx.array
+    from_draft: bool
     prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@@ -207,8 +218,6 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
                     group_size=kv_group_size, bits=kv_bits
                 )

-class TokenMetadata(NamedTuple):
-    from_draft: bool = False

 def generate_step(
     prompt: mx.array,
@@ -224,7 +233,7 @@ def generate_step(
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
     prompt_progress_callback: Optional[Callable[int, int]] = None,
-) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.

@@ -252,8 +261,7 @@ def generate_step(
             prompt tokens processed so far and the total number of prompt tokens.

     Yields:
-        Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
-        probabilities, and token metadata.
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
     """

     y = prompt
@@ -328,6 +336,7 @@ def generate_step(
         y, logprobs = next_y, next_logprobs
         n += 1

+
 def speculative_generate_step(
     prompt: mx.array,
     model: nn.Module,
@@ -342,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.

@@ -369,8 +378,8 @@ def speculative_generate_step(
             when ``kv_bits`` is non-None. Default: ``0``.

     Yields:
-        Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
-        probabilities, and token metadata.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+        and a bool indicating if the token was generated by the draft model
     """

     y = prompt
@@ -455,12 +464,12 @@ def speculative_generate_step(
                     break
                 n += 1
                 ntoks += 1
-                yield tn, lpn, TokenMetadata(from_draft=True)
+                yield tn, lpn, True
                 if ntoks == max_tokens:
                     break
            if ntoks < max_tokens:
                 ntoks += 1
-                yield tokens[n], logprobs[n], TokenMetadata(from_draft=False)
+                yield tokens[n], logprobs[n], False

            if ntoks == max_tokens:
                 break
@@ -523,6 +532,10 @@ def stream_generate(
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
@@ -531,7 +544,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs, token_metadata) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -544,8 +557,8 @@ def stream_generate(
             yield GenerationResponse(
                 text=detokenizer.last_segment,
                 token=token,
-                from_draft=token_metadata.from_draft,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -558,8 +571,8 @@ def stream_generate(
         yield GenerationResponse(
             text=detokenizer.last_segment,
             token=token,
-            from_draft=token_metadata.from_draft,
             logprobs=logprobs,
+            from_draft=from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
             generation_tokens=n + 1,
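
For orientation, a minimal usage sketch of the reworked API (not taken from this commit; the model id and prompt mirror the test below):

    # Sketch: stream tokens with a draft model and read the new from_draft flag.
    from mlx_lm.utils import load, stream_generate

    model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
    # Reuse the target model as the draft model, as the test below does.
    draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

    for response in stream_generate(
        model=model,
        tokenizer=tokenizer,
        prompt="hello",
        max_tokens=16,
        draft_model=draft_model,
        num_draft_tokens=2,
    ):
        # GenerationResponse.from_draft marks tokens accepted from the draft model.
        tag = "draft" if response.from_draft else "target"
        print(f"[{tag}] {response.text}", end="")
    print()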

View File

@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.

 import unittest
+from typing import List

 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)


 class TestGenerate(unittest.TestCase):
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)

+    def test_stream_generate_speculative(self):
+        # Use same model as draft model, this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a determinate sampler
+        sampler = make_sampler(temp=0.0)
+
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+        # since num_draft_tokens is 2 and draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and last two should be drafts
+        self.assertEqual(drafted, [True, True, False, True, True])
+

 if __name__ == "__main__":
     unittest.main()
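
As a sanity check on the asserted pattern, a standalone sketch (plain Python, no mlx) that mimics the accept/yield loop of speculative_generate_step when every draft token is accepted, which is what happens when the draft model is the target model:

    # Sketch: reproduce the expected from_draft pattern for the test above.
    from typing import List


    def drafted_pattern(max_tokens: int, num_draft_tokens: int) -> List[bool]:
        out: List[bool] = []
        ntoks = 0
        while ntoks < max_tokens:
            # Every drafted token is accepted and yielded with from_draft=True.
            for _ in range(min(num_draft_tokens, max_tokens - ntoks)):
                out.append(True)
                ntoks += 1
            # The extra token sampled from the target model: from_draft=False.
            if ntoks < max_tokens:
                out.append(False)
                ntoks += 1
        return out


    # Matches the assertion in test_stream_generate_speculative.
    assert drafted_pattern(5, 2) == [True, True, False, True, True]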