mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Add logits_processor option to generate_step function (#983)
* Add logits_processor option for the generation as in huggingface transformers library * concatenation correction * Rename the tokens variable for clarity * remove the logit_bias argument from generate_step method * fix the variable name * nits + test * test * add back logit bias + test --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
55
llms/tests/test_generate.py
Normal file
55
llms/tests/test_generate.py
Normal file
@@ -0,0 +1,55 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
from mlx_lm.utils import generate, load
|
||||
|
||||
|
||||
class TestGenerate(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||
|
||||
def test_generate(self):
|
||||
# Simple test that generation runs
|
||||
text = generate(
|
||||
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
|
||||
)
|
||||
|
||||
def test_generate_with_logit_bias(self):
|
||||
logit_bias = {0: 2000.0, 1: -20.0}
|
||||
text = generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
"hello",
|
||||
max_tokens=5,
|
||||
verbose=False,
|
||||
logit_bias=logit_bias,
|
||||
)
|
||||
self.assertEqual(text, "!!!!!")
|
||||
|
||||
def test_generate_with_processor(self):
|
||||
init_toks = self.tokenizer.encode("hello")
|
||||
|
||||
all_toks = None
|
||||
|
||||
def logits_processor(toks, logits):
|
||||
nonlocal all_toks
|
||||
all_toks = toks
|
||||
return logits
|
||||
|
||||
generate(
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
"hello",
|
||||
max_tokens=5,
|
||||
verbose=False,
|
||||
logits_processor=logits_processor,
|
||||
)
|
||||
self.assertEqual(len(all_toks), len(init_toks) + 5)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Reference in New Issue
Block a user