Add logits_processor option to generate_step function (#983)

* Add a logits_processor option for generation, as in the Hugging Face transformers library

* concatenation correction

* Rename the tokens variable for clarity

* remove the logit_bias argument from generate_step method

* fix the variable name

* nits + test

* test

* add back logit bias + test

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
nathan
2024-09-28 19:08:49 +02:00
committed by GitHub
parent d812516d3d
commit ace2bb5890
2 changed files with 74 additions and 6 deletions

View File

@@ -0,0 +1,55 @@
# Copyright © 2024 Apple Inc.
import unittest
from mlx_lm.utils import generate, load
class TestGenerate(unittest.TestCase):
    """End-to-end checks for ``generate`` using a small quantized chat model."""

    @classmethod
    def setUpClass(cls):
        """Load the model and tokenizer once and share them across all tests."""
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        cls.model, cls.tokenizer = load(HF_MODEL_PATH)

    def test_generate(self):
        """Smoke test: generation runs to completion without raising."""
        # The generated text is intentionally not inspected here.
        generate(self.model, self.tokenizer, "hello", max_tokens=5, verbose=False)

    def test_generate_with_logit_bias(self):
        """A very large positive bias on one token id should dominate sampling."""
        # Token id 0 is boosted far above everything else; id 1 is suppressed.
        logit_bias = {0: 2000.0, 1: -20.0}
        text = generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logit_bias=logit_bias,
        )
        # Five generated tokens are expected to decode to five "!" characters
        # (presumably token id 0 in this tokenizer's vocabulary -- verify if the
        # model under test changes).
        self.assertEqual(text, "!!!!!")

    def test_generate_with_processor(self):
        """The logits processor is called with the token history so far."""
        init_toks = self.tokenizer.encode("hello")
        all_toks = None

        def logits_processor(toks, logits):
            # Record the token history from the latest call; logits pass through
            # unchanged so generation itself is unaffected.
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processor=logits_processor,
        )
        # Fail with a clear message instead of a TypeError from len(None) when
        # the processor was never invoked.
        self.assertIsNotNone(all_toks, "logits_processor was never called")
        # The last history seen should be the prompt tokens plus 5 generated ones.
        self.assertEqual(len(all_toks), len(init_toks) + 5)
# Allow this test module to be executed directly as a script.
if __name__ == "__main__":
    unittest.main()