# Copyright © 2024 Apple Inc.

import unittest
from typing import List

from mlx_lm.sample_utils import make_logits_processors
from mlx_lm.utils import (
    GenerationResponse,
    generate,
    load,
    make_sampler,
    stream_generate,
)


class TestGenerate(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

    def test_generate(self):
        # Simple test that generation runs and returns a string
        text = generate(
            self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
        )
        self.assertTrue(isinstance(text, str))

    def test_generate_with_logit_bias(self):
        # Token 0 decodes to "!" in this model's vocabulary, so a large
        # positive bias should force it to be sampled at every step.
        logit_bias = {0: 2000.0, 1: -20.0}
        text = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            logits_processors=make_logits_processors(logit_bias),
            verbose=False,
        )
        self.assertEqual(text, "!!!!!")
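
    def test_generate_with_repetition_penalty(self):
        # Sketch, not part of the original suite: make_logits_processors
        # also appears to accept a repetition_penalty argument; treat the
        # exact signature as an assumption. Generation should still run
        # and return a string.
        processors = make_logits_processors(None, repetition_penalty=1.5)
        text = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            logits_processors=processors,
            verbose=False,
        )
        self.assertTrue(isinstance(text, str))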

    def test_generate_with_processor(self):
        init_toks = self.tokenizer.encode("hello")

        all_toks = None

        # A logits processor receives the tokens seen so far and the
        # current logits; record the tokens to check that they accumulate.
        def logits_processor(toks, logits):
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processors=[logits_processor],
        )
        # The final call sees the prompt tokens plus all 5 generated tokens.
        self.assertEqual(len(all_toks), len(init_toks) + 5)
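
    def test_stream_generate(self):
        # Sketch, not part of the original suite: stream_generate yields one
        # GenerationResponse per generated token; assuming each response
        # carries the newly decoded text in a `text` field, the pieces
        # together form the full completion.
        pieces = [
            response.text
            for response in stream_generate(
                self.model, self.tokenizer, "hello", max_tokens=5
            )
        ]
        self.assertEqual(len(pieces), 5)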

    def test_stream_generate_speculative(self):
        # Use the same model as the draft model; this is a correctness
        # test, not a speed test.
        draft_model, _ = load(self.HF_MODEL_PATH)

        results: List[GenerationResponse] = []
        drafted: List[bool] = []

        # Make sampling deterministic (greedy) so draft and target agree.
        sampler = make_sampler(temp=0.0)

        for generation_result in stream_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            prompt="hello",
            max_tokens=5,
            draft_model=draft_model,
            num_draft_tokens=2,
            sampler=sampler,
        ):
            drafted.append(generation_result.from_draft)
            results.append(generation_result)

        self.assertEqual(len(results), 5)
        # Since num_draft_tokens is 2 and the draft model is identical to
        # the target, the first two generations should be drafts, the third
        # should come from the target model, and the last two should be
        # drafts again.
        self.assertEqual(drafted, [True, True, False, True, True])
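
    def test_sampler_determinism(self):
        # Sketch, not part of the original suite: with temp=0.0 the sampler
        # is greedy, so two identical calls should produce identical text.
        # Assumes generate forwards a `sampler=` keyword the way
        # stream_generate does above.
        sampler = make_sampler(temp=0.0)
        first = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            sampler=sampler,
            verbose=False,
        )
        second = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            sampler=sampler,
            verbose=False,
        )
        self.assertEqual(first, second)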


if __name__ == "__main__":
    unittest.main()