mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-06 08:54:33 +08:00
test
This commit is contained in:
@@ -7,18 +7,19 @@ from mlx_lm.utils import generate, load
|
|||||||
|
|
||||||
class TestGenerate(unittest.TestCase):
|
class TestGenerate(unittest.TestCase):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def setUpClass(cls):
|
||||||
|
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||||
|
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
|
||||||
|
|
||||||
def test_generate(self):
|
def test_generate(self):
|
||||||
# Simple test that generation runs
|
# Simple test that generation runs
|
||||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
text = generate(
|
||||||
model, tokenizer = load(HF_MODEL_PATH)
|
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
|
||||||
text = generate(model, tokenizer, "hello", max_tokens=5, verbose=False)
|
)
|
||||||
|
|
||||||
def test_generate_with_processor(self):
|
def test_generate_with_processor(self):
|
||||||
# Simple test that generation runs
|
init_toks = self.tokenizer.encode("hello")
|
||||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
|
||||||
model, tokenizer = load(HF_MODEL_PATH)
|
|
||||||
|
|
||||||
init_toks = tokenizer.encode("hello")
|
|
||||||
|
|
||||||
all_toks = None
|
all_toks = None
|
||||||
|
|
||||||
@@ -28,8 +29,8 @@ class TestGenerate(unittest.TestCase):
|
|||||||
return logits
|
return logits
|
||||||
|
|
||||||
generate(
|
generate(
|
||||||
model,
|
self.model,
|
||||||
tokenizer,
|
self.tokenizer,
|
||||||
"hello",
|
"hello",
|
||||||
max_tokens=5,
|
max_tokens=5,
|
||||||
verbose=False,
|
verbose=False,
|
||||||
|
Reference in New Issue
Block a user