Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 01:17:28 +08:00)
Add "from_draft" to GenerationResponse (#1272)
* Add from_draft field in GenerationResponse
* Cleanup
* Re-work for minimal changes, add test
* Fix comment
parent bded1a8fcd
commit 3d677f0870
@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
    prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
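For context, a minimal sketch of how downstream code might read the new field from a GenerationResponse. The helper and its name are illustrative only and not part of this commit:

from mlx_lm.utils import GenerationResponse

def describe(response: GenerationResponse) -> str:
    # from_draft is True only for tokens proposed by the draft model and
    # accepted by the target model; it is False for every other token.
    origin = "draft" if response.from_draft else "target"
    return f"token {response.token} from {origin} model: {response.text!r}"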
@@ -338,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
 
@@ -365,7 +378,8 @@ def speculative_generate_step(
            when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+        and a bool indicating if the token was generated by the draft model
     """
 
     y = prompt
@@ -450,12 +464,12 @@ def speculative_generate_step(
                     break
                 n += 1
                 ntoks += 1
-                yield tn, lpn
+                yield tn, lpn, True
                 if ntoks == max_tokens:
                     break
             if ntoks < max_tokens:
                 ntoks += 1
-                yield tokens[n], logprobs[n]
+                yield tokens[n], logprobs[n], False
 
             if ntoks == max_tokens:
                 break
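Callers that iterate over speculative_generate_step directly now unpack a 3-tuple instead of a 2-tuple. A hedged sketch of a helper that tallies how many yielded tokens came from the draft model; the function and its name are illustrative, not part of the library:

from typing import Generator, Tuple

import mlx.core as mx

def draft_acceptance_rate(
    step_generator: Generator[Tuple[mx.array, mx.array, bool], None, None],
) -> float:
    # Count the from_draft flag over every (token, logprobs, from_draft)
    # tuple yielded by the generator.
    drafted = total = 0
    for _token, _logprobs, from_draft in step_generator:
        total += 1
        drafted += int(from_draft)
    return drafted / total if total else 0.0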
@@ -463,7 +477,7 @@ def speculative_generate_step(
             y = mx.array([tokens[n]], mx.uint32)
             draft_y = y
 
-            # If we accpeted all the draft tokens, include the last
+            # If we accepted all the draft tokens, include the last
             # draft token in the next draft step since it hasn't been
             # processed yet by the draft model
             if n == num_draft:
@@ -518,6 +532,10 @@ def stream_generate(
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
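The non-speculative branch leaves generate_step untouched and instead pads its (token, logprobs) pairs with a constant False, so the loop below can always unpack three values. A toy illustration of that generator-expression pattern, with made-up data:

# Stand-in for generate_step output: (token, logprobs) pairs.
pairs = iter([(42, -0.1), (7, -2.3)])

# Same wrapping trick as in the diff: append a constant from_draft flag.
triples = ((token, logprobs, False) for token, logprobs in pairs)

assert list(triples) == [(42, -0.1, False), (7, -2.3, False)]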
@@ -526,7 +544,7 @@ def stream_generate(
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
             if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@ def stream_generate(
                 text=detokenizer.last_segment,
                 token=token,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -553,6 +572,7 @@ def stream_generate(
                 text=detokenizer.last_segment,
                 token=token,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.
 
 import unittest
+from typing import List
 
 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)
 
 
 class TestGenerate(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
 
     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
 
+    def test_stream_generate_speculative(self):
+        # Use same model as draft model, this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a determinate sampler
+        sampler = make_sampler(temp=0.0)
+
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+        # since num_draft_tokens is 2 and draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and last two should be drafts
+        self.assertEqual(drafted, [True, True, False, True, True])
+
 
 if __name__ == "__main__":
     unittest.main()
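Putting the pieces together, a hedged end-to-end sketch of reporting how many generated tokens were accepted from the draft model via stream_generate. The model path and keyword arguments mirror the test above; the prompt and token budget are arbitrary, and using the same checkpoint as its own draft model is only for illustration:

from mlx_lm.utils import load, make_sampler, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
# Same checkpoint reused as the draft model, as in the test above.
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

drafted = total = 0
for response in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="hello",
    max_tokens=32,
    draft_model=draft_model,
    num_draft_tokens=2,
    sampler=make_sampler(temp=0.0),
):
    total += 1
    drafted += int(response.from_draft)
    print(response.text, end="", flush=True)

print(f"\n{drafted}/{total} tokens came from the draft model")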