diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b2e89a13..64813123 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -13,7 +13,18 @@ import time
 from dataclasses import dataclass
 from pathlib import Path
 from textwrap import dedent
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    List,
+    NamedTuple,
+    Optional,
+    Tuple,
+    Type,
+    Union,
+)
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -65,6 +76,7 @@ class GenerationResponse:
     Args:
         text (str): The next segment of decoded text. This can be an empty string.
         token (int): The next token.
+        from_draft (bool): Whether the token was generated by the draft model.
         logprobs (mx.array): A vector of log probabilities.
         prompt_tokens (int): The number of tokens in the prompt.
         prompt_tps (float): The prompt processing tokens-per-second.
@@ -77,6 +89,7 @@ class GenerationResponse:
     text: str
     token: int
     logprobs: mx.array
+    from_draft: bool
     prompt_tokens: int
     prompt_tps: float
     generation_tokens: int
@@ -338,7 +351,7 @@ def speculative_generate_step(
     kv_bits: Optional[int] = None,
     kv_group_size: int = 64,
     quantized_kv_start: int = 0,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
+) -> Generator[Tuple[mx.array, mx.array, bool], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
 
@@ -365,7 +378,8 @@
             when ``kv_bits`` is non-None. Default: ``0``.
 
     Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+        Tuple[mx.array, mx.array, bool]: One token, a vector of log probabilities,
+          and a bool indicating if the token was generated by the draft model
     """
 
     y = prompt
@@ -450,12 +464,12 @@
                 break
             n += 1
             ntoks += 1
-            yield tn, lpn
+            yield tn, lpn, True
             if ntoks == max_tokens:
                 break
         if ntoks < max_tokens:
             ntoks += 1
-            yield tokens[n], logprobs[n]
+            yield tokens[n], logprobs[n], False
 
         if ntoks == max_tokens:
             break
@@ -463,7 +477,7 @@
         y = mx.array([tokens[n]], mx.uint32)
         draft_y = y
 
-        # If we accpeted all the draft tokens, include the last
+        # If we accepted all the draft tokens, include the last
         # draft token in the next draft step since it hasn't been
         # processed yet by the draft model
         if n == num_draft:
@@ -518,6 +532,10 @@
     if draft_model is None:
         kwargs.pop("num_draft_tokens", None)
         token_generator = generate_step(prompt, model, **kwargs)
+        # from_draft always false for non-speculative generation
+        token_generator = (
+            (token, logprobs, False) for token, logprobs in token_generator
+        )
     else:
         kwargs.pop("max_kv_size", None)
         token_generator = speculative_generate_step(
@@ -526,7 +544,7 @@
     with wired_limit(model, [generation_stream]):
         detokenizer.reset()
         tic = time.perf_counter()
-        for n, (token, logprobs) in enumerate(token_generator):
+        for n, (token, logprobs, from_draft) in enumerate(token_generator):
            if n == 0:
                 prompt_time = time.perf_counter() - tic
                 prompt_tps = prompt.size / prompt_time
@@ -540,6 +558,7 @@
                 text=detokenizer.last_segment,
                 token=token,
                 logprobs=logprobs,
+                from_draft=from_draft,
                 prompt_tokens=prompt.size,
                 prompt_tps=prompt_tps,
                 generation_tokens=n + 1,
@@ -553,6 +572,7 @@
             text=detokenizer.last_segment,
             token=token,
             logprobs=logprobs,
+            from_draft=from_draft,
             prompt_tokens=prompt.size,
             prompt_tps=prompt_tps,
             generation_tokens=n + 1,
diff --git a/llms/tests/test_generate.py b/llms/tests/test_generate.py
index f2345394..7445a9b9 100644
--- a/llms/tests/test_generate.py
+++ b/llms/tests/test_generate.py
@@ -1,17 +1,24 @@
 # Copyright © 2024 Apple Inc.
 
 import unittest
+from typing import List
 
 from mlx_lm.sample_utils import make_logits_processors
-from mlx_lm.utils import generate, load
+from mlx_lm.utils import (
+    GenerationResponse,
+    generate,
+    load,
+    make_sampler,
+    stream_generate,
+)
 
 
 class TestGenerate(unittest.TestCase):
 
     @classmethod
     def setUpClass(cls):
-        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
-        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
+        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
+        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
 
     def test_generate(self):
         # Simple test that generation runs
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
         )
         self.assertEqual(len(all_toks), len(init_toks) + 5)
 
+    def test_stream_generate_speculative(self):
+        # Use the same model as the draft model; this is not a speed test
+        draft_model, _ = load(self.HF_MODEL_PATH)
+
+        results: List[GenerationResponse] = []
+        drafted: List[bool] = []
+
+        # make a deterministic sampler
+        sampler = make_sampler(temp=0.0)
+
+        for generation_result in stream_generate(
+            model=self.model,
+            tokenizer=self.tokenizer,
+            prompt="hello",
+            max_tokens=5,
+            draft_model=draft_model,
+            num_draft_tokens=2,
+            sampler=sampler,
+        ):
+            drafted.append(generation_result.from_draft)
+            results.append(generation_result)
+
+        self.assertEqual(len(results), 5)
+        # Since num_draft_tokens is 2 and the draft model is the same, the
+        # first 2 generations should be drafts, the third should come
+        # from the target model, and the last two should be drafts
+        self.assertEqual(drafted, [True, True, False, True, True])
+
 
 if __name__ == "__main__":
     unittest.main()
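For context, a minimal sketch (not part of the patch) of how a caller could consume the new from_draft flag via stream_generate. The prompt, token budget, and the reuse of the target checkpoint as its own draft model are illustrative and mirror the test above; any draft/target pair loadable by mlx_lm is used the same way.

# Illustrative usage only; mirrors the new test rather than prescribing an API.
from mlx_lm.utils import load, stream_generate

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
draft_model, _ = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

n_drafted = 0
n_total = 0
for response in stream_generate(
    model=model,
    tokenizer=tokenizer,
    prompt="hello",
    max_tokens=32,
    draft_model=draft_model,
    num_draft_tokens=2,
):
    n_total += 1
    n_drafted += response.from_draft  # True when the draft model produced the token
    print(response.text, end="", flush=True)

print(f"\ndraft tokens accepted: {n_drafted}/{n_total}")

Counting from_draft over a run gives a quick view of the draft acceptance rate, which is the main signal for tuning num_draft_tokens.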