Re-work for minimal changes, add test

Matt Clayton 2025-02-10 12:30:12 -05:00
parent fff5daeb85
commit 4df23e961c
2 changed files with 66 additions and 18 deletions

View File

@@ -13,7 +13,18 @@ import time
from dataclasses import dataclass
from pathlib import Path
from textwrap import dedent
from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
from typing import (
Any,
Callable,
Dict,
Generator,
List,
NamedTuple,
Optional,
Tuple,
Type,
Union,
)
import mlx.core as mx
import mlx.nn as nn
@@ -77,8 +88,8 @@ class GenerationResponse:
text: str
token: int
from_draft: bool = False
logprobs: mx.array
from_draft: bool
prompt_tokens: int
prompt_tps: float
generation_tokens: int
@@ -207,8 +218,6 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
group_size=kv_group_size, bits=kv_bits
)
class TokenMetadata(NamedTuple):
from_draft: bool = False
def generate_step(
prompt: mx.array,
@@ -224,7 +233,7 @@ def generate_step(
kv_group_size: int = 64,
quantized_kv_start: int = 0,
prompt_progress_callback: Optional[Callable[int, int]] = None,
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -252,8 +261,7 @@
prompt tokens processed so far and the total number of prompt tokens.
Yields:
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
probabilities, and token metadata.
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
@@ -328,6 +336,7 @@ def generate_step(
y, logprobs = next_y, next_logprobs
n += 1
def speculative_generate_step(
prompt: mx.array,
model: nn.Module,
@@ -342,7 +351,7 @@ def speculative_generate_step(
kv_bits: Optional[int] = None,
kv_group_size: int = 64,
quantized_kv_start: int = 0,
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
@@ -369,8 +378,8 @@ def speculative_generate_step(
when ``kv_bits`` is non-None. Default: ``0``.
Yields:
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
probabilities, and token metadata.
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
and a bool indicating if the token was generated by the draft model
"""
y = prompt
@@ -455,12 +464,12 @@
break
n += 1
ntoks += 1
yield tn, lpn, TokenMetadata(from_draft=True)
yield tn, lpn, True
if ntoks == max_tokens:
break
if ntoks < max_tokens:
ntoks += 1
yield tokens[n], logprobs[n], TokenMetadata(from_draft=False)
yield tokens[n], logprobs[n], False
if ntoks == max_tokens:
break
@@ -523,6 +532,10 @@ def stream_generate(
if draft_model is None:
kwargs.pop("num_draft_tokens", None)
token_generator = generate_step(prompt, model, **kwargs)
# from_draft is always False for non-speculative generation
token_generator = (
(token, logprobs, False) for token, logprobs in token_generator
)
else:
kwargs.pop("max_kv_size", None)
token_generator = speculative_generate_step(
@@ -531,7 +544,7 @@
with wired_limit(model, [generation_stream]):
detokenizer.reset()
tic = time.perf_counter()
for n, (token, logprobs, token_metadata) in enumerate(token_generator):
for n, (token, logprobs, from_draft) in enumerate(token_generator):
if n == 0:
prompt_time = time.perf_counter() - tic
prompt_tps = prompt.size / prompt_time
@@ -544,8 +557,8 @@
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
from_draft=token_metadata.from_draft,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
@@ -558,8 +571,8 @@
yield GenerationResponse(
text=detokenizer.last_segment,
token=token,
from_draft=token_metadata.from_draft,
logprobs=logprobs,
from_draft=from_draft,
prompt_tokens=prompt.size,
prompt_tps=prompt_tps,
generation_tokens=n + 1,
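
For context (not part of the commit): a minimal sketch of how a caller might consume the reworked from_draft flag on GenerationResponse. The import path and arguments mirror the new test below; the model path, prompt, and max_tokens value are only illustrative.

from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
# Use the target model as its own draft model, as the test below does.
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

for response in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="hello",
    max_tokens=16,
    draft_model=draft_model,
    num_draft_tokens=2,
):
    # from_draft is True when the token was proposed by the draft model and
    # accepted during speculative decoding; it is False when the token came
    # from the target model (and always False without a draft model).
    tag = "draft" if response.from_draft else "target"
    print(f"[{tag}] {response.text}", end="")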

View File

@@ -1,17 +1,24 @@
# Copyright © 2024 Apple Inc.
import unittest
from typing import List
from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import generate, load
from mlx_lm.utils import (
GenerationResponse,
generate,
load,
make_sampler,
stream_generate,
)
class TestGenerate(unittest.TestCase):
@classmethod
def setUpClass(cls):
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
def test_generate(self):
# Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
)
self.assertEqual(len(all_toks), len(init_toks) + 5)
def test_stream_generate_speculative(self):
# Use the same model as the draft model; this is not a speed test
draft_model, _ = load(self.HF_MODEL_PATH)
results: List[GenerationResponse] = []
drafted: List[bool] = []
# make a deterministic sampler
sampler = make_sampler(temp=0.0)
for generation_result in stream_generate(
model=self.model,
tokenizer=self.tokenizer,
prompt="hello",
max_tokens=5,
draft_model=draft_model,
num_draft_tokens=2,
sampler=sampler,
):
drafted.append(generation_result.from_draft)
results.append(generation_result)
self.assertEqual(len(results), 5)
# Since num_draft_tokens is 2 and the draft model is the same, the
# first 2 generations should be drafts, the third should come
# from the target model, and the last two should be drafts
self.assertEqual(drafted, [True, True, False, True, True])
if __name__ == "__main__":
unittest.main()
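
A worked sketch (not part of the commit) of why the test expects [True, True, False, True, True]: assuming every draft token is accepted (the draft and target models are identical and the sampler is greedy), each round yields num_draft_tokens accepted draft tokens followed by one token from the target model, until max_tokens is reached.

num_draft_tokens = 2
max_tokens = 5

expected = []
while len(expected) < max_tokens:
    # accepted draft tokens for this round (all accepted by assumption)
    expected.extend([True] * min(num_draft_tokens, max_tokens - len(expected)))
    # one target-model token closes the round, if the token budget allows
    if len(expected) < max_tokens:
        expected.append(False)

assert expected == [True, True, False, True, True]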