mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 06:54:39 +08:00
Re-work for minimal changes, add test
This commit is contained in:
parent
fff5daeb85
commit
4df23e961c
@ -13,7 +13,18 @@ import time
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from textwrap import dedent
|
from textwrap import dedent
|
||||||
from typing import Any, Callable, Dict, Generator, List, NamedTuple, Optional, Tuple, Type, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generator,
|
||||||
|
List,
|
||||||
|
NamedTuple,
|
||||||
|
Optional,
|
||||||
|
Tuple,
|
||||||
|
Type,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
@ -77,8 +88,8 @@ class GenerationResponse:
|
|||||||
|
|
||||||
text: str
|
text: str
|
||||||
token: int
|
token: int
|
||||||
from_draft: bool = False
|
|
||||||
logprobs: mx.array
|
logprobs: mx.array
|
||||||
|
from_draft: bool
|
||||||
prompt_tokens: int
|
prompt_tokens: int
|
||||||
prompt_tps: float
|
prompt_tps: float
|
||||||
generation_tokens: int
|
generation_tokens: int
|
||||||
@ -207,8 +218,6 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
|
|||||||
group_size=kv_group_size, bits=kv_bits
|
group_size=kv_group_size, bits=kv_bits
|
||||||
)
|
)
|
||||||
|
|
||||||
class TokenMetadata(NamedTuple):
|
|
||||||
from_draft: bool = False
|
|
||||||
|
|
||||||
def generate_step(
|
def generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
@ -224,7 +233,7 @@ def generate_step(
|
|||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
quantized_kv_start: int = 0,
|
quantized_kv_start: int = 0,
|
||||||
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
prompt_progress_callback: Optional[Callable[int, int]] = None,
|
||||||
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
@ -252,8 +261,7 @@ def generate_step(
|
|||||||
prompt tokens processed so far and the total number of prompt tokens.
|
prompt tokens processed so far and the total number of prompt tokens.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
|
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
|
||||||
probabilities, and token metadata.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
@ -328,6 +336,7 @@ def generate_step(
|
|||||||
y, logprobs = next_y, next_logprobs
|
y, logprobs = next_y, next_logprobs
|
||||||
n += 1
|
n += 1
|
||||||
|
|
||||||
|
|
||||||
def speculative_generate_step(
|
def speculative_generate_step(
|
||||||
prompt: mx.array,
|
prompt: mx.array,
|
||||||
model: nn.Module,
|
model: nn.Module,
|
||||||
@ -342,7 +351,7 @@ def speculative_generate_step(
|
|||||||
kv_bits: Optional[int] = None,
|
kv_bits: Optional[int] = None,
|
||||||
kv_group_size: int = 64,
|
kv_group_size: int = 64,
|
||||||
quantized_kv_start: int = 0,
|
quantized_kv_start: int = 0,
|
||||||
) -> Generator[Tuple[mx.array, mx.array, TokenMetadata], None, None]:
|
) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
|
||||||
"""
|
"""
|
||||||
A generator producing token ids based on the given prompt from the model.
|
A generator producing token ids based on the given prompt from the model.
|
||||||
|
|
||||||
@ -369,8 +378,8 @@ def speculative_generate_step(
|
|||||||
when ``kv_bits`` is non-None. Default: ``0``.
|
when ``kv_bits`` is non-None. Default: ``0``.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
Tuple[mx.array, mx.array, TokenMetadata]: One token, a vector of log
|
Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
|
||||||
probabilities, and token metadata.
|
and a bool indicating if the token was generated by the draft model
|
||||||
"""
|
"""
|
||||||
|
|
||||||
y = prompt
|
y = prompt
|
||||||
@ -455,12 +464,12 @@ def speculative_generate_step(
|
|||||||
break
|
break
|
||||||
n += 1
|
n += 1
|
||||||
ntoks += 1
|
ntoks += 1
|
||||||
yield tn, lpn, TokenMetadata(from_draft=True)
|
yield tn, lpn, True
|
||||||
if ntoks == max_tokens:
|
if ntoks == max_tokens:
|
||||||
break
|
break
|
||||||
if ntoks < max_tokens:
|
if ntoks < max_tokens:
|
||||||
ntoks += 1
|
ntoks += 1
|
||||||
yield tokens[n], logprobs[n], TokenMetadata(from_draft=False)
|
yield tokens[n], logprobs[n], False
|
||||||
|
|
||||||
if ntoks == max_tokens:
|
if ntoks == max_tokens:
|
||||||
break
|
break
|
||||||
@ -523,6 +532,10 @@ def stream_generate(
|
|||||||
if draft_model is None:
|
if draft_model is None:
|
||||||
kwargs.pop("num_draft_tokens", None)
|
kwargs.pop("num_draft_tokens", None)
|
||||||
token_generator = generate_step(prompt, model, **kwargs)
|
token_generator = generate_step(prompt, model, **kwargs)
|
||||||
|
# from_draft always false for non-speculative generation
|
||||||
|
token_generator = (
|
||||||
|
(token, logprobs, False) for token, logprobs in token_generator
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
kwargs.pop("max_kv_size", None)
|
kwargs.pop("max_kv_size", None)
|
||||||
token_generator = speculative_generate_step(
|
token_generator = speculative_generate_step(
|
||||||
@ -531,7 +544,7 @@ def stream_generate(
|
|||||||
with wired_limit(model, [generation_stream]):
|
with wired_limit(model, [generation_stream]):
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
for n, (token, logprobs, token_metadata) in enumerate(token_generator):
|
for n, (token, logprobs, from_draft) in enumerate(token_generator):
|
||||||
if n == 0:
|
if n == 0:
|
||||||
prompt_time = time.perf_counter() - tic
|
prompt_time = time.perf_counter() - tic
|
||||||
prompt_tps = prompt.size / prompt_time
|
prompt_tps = prompt.size / prompt_time
|
||||||
@ -544,8 +557,8 @@ def stream_generate(
|
|||||||
yield GenerationResponse(
|
yield GenerationResponse(
|
||||||
text=detokenizer.last_segment,
|
text=detokenizer.last_segment,
|
||||||
token=token,
|
token=token,
|
||||||
from_draft=token_metadata.from_draft,
|
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
|
from_draft=from_draft,
|
||||||
prompt_tokens=prompt.size,
|
prompt_tokens=prompt.size,
|
||||||
prompt_tps=prompt_tps,
|
prompt_tps=prompt_tps,
|
||||||
generation_tokens=n + 1,
|
generation_tokens=n + 1,
|
||||||
@ -558,8 +571,8 @@ def stream_generate(
|
|||||||
yield GenerationResponse(
|
yield GenerationResponse(
|
||||||
text=detokenizer.last_segment,
|
text=detokenizer.last_segment,
|
||||||
token=token,
|
token=token,
|
||||||
from_draft=token_metadata.from_draft,
|
|
||||||
logprobs=logprobs,
|
logprobs=logprobs,
|
||||||
|
from_draft=from_draft,
|
||||||
prompt_tokens=prompt.size,
|
prompt_tokens=prompt.size,
|
||||||
prompt_tps=prompt_tps,
|
prompt_tps=prompt_tps,
|
||||||
generation_tokens=n + 1,
|
generation_tokens=n + 1,
|
||||||
|
@ -1,17 +1,24 @@
|
|||||||
# Copyright © 2024 Apple Inc.
|
# Copyright © 2024 Apple Inc.
|
||||||
|
|
||||||
import unittest
|
import unittest
|
||||||
|
from typing import List
|
||||||
|
|
||||||
from mlx_lm.sample_utils import make_logits_processors
|
from mlx_lm.sample_utils import make_logits_processors
|
||||||
from mlx_lm.utils import generate, load
|
from mlx_lm.utils import (
|
||||||
|
GenerationResponse,
|
||||||
|
generate,
|
||||||
|
load,
|
||||||
|
make_sampler,
|
||||||
|
stream_generate,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestGenerate(unittest.TestCase):
|
class TestGenerate(unittest.TestCase):
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def setUpClass(cls):
|
def setUpClass(cls):
|
||||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
|
||||||
|
|
||||||
def test_generate(self):
|
def test_generate(self):
|
||||||
# Simple test that generation runs
|
# Simple test that generation runs
|
||||||
@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||||
|
|
||||||
|
def test_stream_generate_speculative(self):
|
||||||
|
# Use same model as draft model, this is not a speed test
|
||||||
|
draft_model, _ = load(self.HF_MODEL_PATH)
|
||||||
|
|
||||||
|
results: List[GenerationResponse] = []
|
||||||
|
drafted: List[bool] = []
|
||||||
|
|
||||||
|
# make a determinate sampler
|
||||||
|
sampler = make_sampler(temp=0.0)
|
||||||
|
|
||||||
|
for generation_result in stream_generate(
|
||||||
|
model=self.model,
|
||||||
|
tokenizer=self.tokenizer,
|
||||||
|
prompt="hello",
|
||||||
|
max_tokens=5,
|
||||||
|
draft_model=draft_model,
|
||||||
|
num_draft_tokens=2,
|
||||||
|
sampler=sampler,
|
||||||
|
):
|
||||||
|
drafted.append(generation_result.from_draft)
|
||||||
|
results.append(generation_result)
|
||||||
|
|
||||||
|
self.assertEqual(len(results), 5)
|
||||||
|
# since num_draft_tokens is 2 and draft model is the same, the
|
||||||
|
# first 2 generations should be drafts, the third should come
|
||||||
|
# from the target model, and last two should be drafts
|
||||||
|
self.assertEqual(drafted, [True, True, False, True, True])
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user