mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Add "from_draft" to GenerationResponse (#1272)
* Add from_draft field in GenerationResponse * Cleanup * Re-work for minimal changes, add test * Fix comment
This commit is contained in:
@@ -1,17 +1,24 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
from typing import List
|
||||
|
||||
from mlx_lm.sample_utils import make_logits_processors
|
||||
from mlx_lm.utils import generate, load
|
||||
from mlx_lm.utils import (
|
||||
GenerationResponse,
|
||||
generate,
|
||||
load,
|
||||
make_sampler,
|
||||
stream_generate,
|
||||
)
|
||||
|
||||
|
||||
class TestGenerate(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||
cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)
|
||||
|
||||
def test_generate(self):
|
||||
# Simple test that generation runs
|
||||
@@ -51,6 +58,34 @@ class TestGenerate(unittest.TestCase):
|
||||
)
|
||||
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||
|
||||
def test_stream_generate_speculative(self):
|
||||
# Use same model as draft model, this is not a speed test
|
||||
draft_model, _ = load(self.HF_MODEL_PATH)
|
||||
|
||||
results: List[GenerationResponse] = []
|
||||
drafted: List[bool] = []
|
||||
|
||||
# make a determinate sampler
|
||||
sampler = make_sampler(temp=0.0)
|
||||
|
||||
for generation_result in stream_generate(
|
||||
model=self.model,
|
||||
tokenizer=self.tokenizer,
|
||||
prompt="hello",
|
||||
max_tokens=5,
|
||||
draft_model=draft_model,
|
||||
num_draft_tokens=2,
|
||||
sampler=sampler,
|
||||
):
|
||||
drafted.append(generation_result.from_draft)
|
||||
results.append(generation_result)
|
||||
|
||||
self.assertEqual(len(results), 5)
|
||||
# since num_draft_tokens is 2 and draft model is the same, the
|
||||
# first 2 generations should be drafts, the third should come
|
||||
# from the target model, and last two should be drafts
|
||||
self.assertEqual(drafted, [True, True, False, True, True])
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user