mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 04:14:38 +08:00)

Commit: nits + test
@@ -212,7 +212,7 @@ def generate_step(
         )
 
     y = prompt
-    tokens = prompt
+    tokens = None
 
     # Create the KV cache for generation
     cache = make_kv_caches(model, max_kv_size)
@@ -233,11 +233,13 @@ def generate_step(
         repetition_context = repetition_context[-repetition_context_size:]
 
     def _step(y):
-        nonlocal repetition_context, tokens
+        nonlocal repetition_context
         logits = model(y[None], cache=cache)
         logits = logits[:, -1, :]
 
         if logits_processor:
+            nonlocal tokens
+            tokens = mx.concat([tokens, y]) if tokens is not None else y
             logits = logits_processor(tokens, logits)
 
         if repetition_penalty:
@@ -249,8 +251,6 @@ def generate_step(
         else:
             y, logprobs = sample(logits)
 
-        tokens = mx.concat([tokens, y], axis=0)
-
        if repetition_context_size:
            if len(repetition_context) > repetition_context_size:
                repetition_context = repetition_context[-repetition_context_size:]
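Taken together, the three hunks make token accumulation in generate_step lazy (presumably in llms/mlx_lm/utils.py, given the test below imports mlx_lm.utils): tokens starts out as None and is only extended inside _step when a logits_processor callback is actually supplied, so the per-step mx.concat is no longer paid when no processor is in use. The processor is called as logits_processor(tokens, logits) and must return logits of the same shape. As a rough illustration of that interface (a minimal sketch, not part of the commit; ban_token_processor and the use of mx.where are assumptions), a processor that masks out a single token id could look like:

import mlx.core as mx


def ban_token_processor(banned_id: int):
    # Hypothetical helper, for illustration only: builds a processor with the
    # (tokens, logits) signature used in the diff above.
    def processor(tokens, logits):
        # tokens: every id generated so far; logits: shape (1, vocab_size).
        mask = mx.arange(logits.shape[-1]) == banned_id
        # Push the banned id's logit to -inf so it can never be sampled.
        return mx.where(mask, float("-inf"), logits)

    return processor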
llms/tests/test_generate.py (new file, 42 lines):
# Copyright © 2024 Apple Inc.

import unittest

from mlx_lm.utils import generate, load


class TestGenerate(unittest.TestCase):

    def test_generate(self):
        # Simple test that generation runs
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        model, tokenizer = load(HF_MODEL_PATH)
        text = generate(model, tokenizer, "hello", max_tokens=5, verbose=False)

    def test_generate_with_processor(self):
        # Simple test that generation runs
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        model, tokenizer = load(HF_MODEL_PATH)

        init_toks = tokenizer.encode("hello")

        all_toks = None

        def logits_processor(toks, logits):
            nonlocal all_toks
            all_toks = toks
            return logits

        generate(
            model,
            tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,
            logits_processor=logits_processor,
        )
        self.assertEqual(len(all_toks), len(init_toks) + 5)


if __name__ == "__main__":
    unittest.main()
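Both tests download the mlx-community/Qwen1.5-0.5B-Chat-4bit checkpoint from the Hugging Face Hub on first run, so they need network access. Because the file ends with a __main__ guard, it can be run directly (a standard unittest invocation, not something the commit prescribes), e.g. python llms/tests/test_generate.py from the repository root with mlx_lm installed.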