mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
test
This commit is contained in:
@@ -7,18 +7,19 @@ from mlx_lm.utils import generate, load
|
||||
|
||||
class TestGenerate(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||
|
||||
def test_generate(self):
|
||||
# Simple test that generation runs
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
model, tokenizer = load(HF_MODEL_PATH)
|
||||
text = generate(model, tokenizer, "hello", max_tokens=5, verbose=False)
|
||||
text = generate(
|
||||
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
|
||||
)
|
||||
|
||||
def test_generate_with_processor(self):
|
||||
# Simple test that generation runs
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
model, tokenizer = load(HF_MODEL_PATH)
|
||||
|
||||
init_toks = tokenizer.encode("hello")
|
||||
init_toks = self.tokenizer.encode("hello")
|
||||
|
||||
all_toks = None
|
||||
|
||||
@@ -28,8 +29,8 @@ class TestGenerate(unittest.TestCase):
|
||||
return logits
|
||||
|
||||
generate(
|
||||
model,
|
||||
tokenizer,
|
||||
self.model,
|
||||
self.tokenizer,
|
||||
"hello",
|
||||
max_tokens=5,
|
||||
verbose=False,
|
||||
|
Reference in New Issue
Block a user