# Copyright © 2024 Apple Inc.

import unittest
from typing import List

from mlx_lm.sample_utils import make_logits_processors, make_sampler
from mlx_lm.utils import (
    GenerationResponse,
    generate,
    load,
    stream_generate,
)


class TestGenerate(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

    def test_generate(self):
        # Simple smoke test that generation runs and returns text
        text = generate(
            self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
        )
        self.assertTrue(isinstance(text, str))

    def test_generate_with_logit_bias(self):
        # A large positive bias on token 0 forces it to be sampled every step
        logit_bias = {0: 2000.0, 1: -20.0}
        text = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            logits_processors=make_logits_processors(logit_bias),
            verbose=False,
        )
        self.assertEqual(text, "!!!!!")

    def test_generate_with_processor(self):
        init_toks = self.tokenizer.encode("hello")

        all_toks = None

        def logits_processor(toks, logits):
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processors=[logits_processor],
        )
        # The processor sees the prompt tokens plus every generated token
        self.assertEqual(len(all_toks), len(init_toks) + 5)

    def test_stream_generate_speculative(self):
        # Use the same model as the draft model; this is a correctness
        # test, not a speed test.
        draft_model, _ = load(self.HF_MODEL_PATH)
        results: List[GenerationResponse] = []
        drafted: List[bool] = []

        # Make a deterministic (greedy) sampler so the draft tokens
        # always match the target model's tokens
        sampler = make_sampler(temp=0.0)
        for generation_result in stream_generate(
            model=self.model,
            tokenizer=self.tokenizer,
            prompt="hello",
            max_tokens=5,
            draft_model=draft_model,
            num_draft_tokens=2,
            sampler=sampler,
        ):
            drafted.append(generation_result.from_draft)
            results.append(generation_result)

        self.assertEqual(len(results), 5)
        # Since num_draft_tokens is 2 and the draft model is identical to
        # the target model, the first 2 generations should be drafts, the
        # third should come from the target model's verification step, and
        # the last 2 should be drafts again.
        self.assertEqual(drafted, [True, True, False, True, True])


if __name__ == "__main__":
    unittest.main()