nits + test

This commit is contained in:
Awni Hannun
2024-09-28 08:28:37 -07:00
parent 50e4665c1d
commit 824f7fda58
2 changed files with 46 additions and 4 deletions

View File

@@ -212,7 +212,7 @@ def generate_step(
)
y = prompt
tokens = prompt
tokens = None
# Create the KV cache for generation
cache = make_kv_caches(model, max_kv_size)
@@ -233,11 +233,13 @@ def generate_step(
repetition_context = repetition_context[-repetition_context_size:]
def _step(y):
nonlocal repetition_context, tokens
nonlocal repetition_context
logits = model(y[None], cache=cache)
logits = logits[:, -1, :]
if logits_processor:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
logits = logits_processor(tokens, logits)
if repetition_penalty:
@@ -249,8 +251,6 @@ def generate_step(
else:
y, logprobs = sample(logits)
tokens = mx.concat([tokens, y], axis=0)
if repetition_context_size:
if len(repetition_context) > repetition_context_size:
repetition_context = repetition_context[-repetition_context_size:]

View File

@@ -0,0 +1,42 @@
# Copyright © 2024 Apple Inc.
import unittest
from mlx_lm.utils import generate, load
class TestGenerate(unittest.TestCase):
    """End-to-end smoke tests for ``mlx_lm.utils.generate``."""

    HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"

    @classmethod
    def setUpClass(cls):
        # Load the model and tokenizer once and share them across tests
        # instead of repeating the load in every test method.
        cls.model, cls.tokenizer = load(cls.HF_MODEL_PATH)

    def test_generate(self):
        """Generation with a short prompt completes without raising."""
        generate(self.model, self.tokenizer, "hello", max_tokens=5, verbose=False)

    def test_generate_with_processor(self):
        """A logits processor is called with the accumulated token history.

        The token array handed to the last processor call is expected to
        have length ``len(prompt tokens) + max_tokens``.
        """
        max_tokens = 5
        init_toks = self.tokenizer.encode("hello")

        all_toks = None

        def logits_processor(toks, logits):
            # Record the most recent token history; leave logits untouched.
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=max_tokens,
            verbose=False,
            logits_processor=logits_processor,
        )

        # Fail with a clear message rather than a TypeError from len(None)
        # if the processor was never invoked.
        self.assertIsNotNone(all_toks, "logits_processor was never called")
        self.assertEqual(len(all_toks), len(init_toks) + max_tokens)
# Run the test suite when this module is executed directly as a script.
if __name__ == "__main__":
    unittest.main()