This commit is contained in:
Awni Hannun
2024-09-28 08:32:09 -07:00
parent 824f7fda58
commit c8216caa61

View File

@@ -7,18 +7,19 @@ from mlx_lm.utils import generate, load
class TestGenerate(unittest.TestCase):
# Class-level unittest fixture: loads the model and tokenizer and stores them
# on the class as `cls.model` / `cls.tokenizer`, which the test methods read
# through `self.model` / `self.tokenizer`.
@classmethod
def setUpClass(cls):
# Identifier handed to `load`; presumably a Hugging Face Hub repo id, going by
# the constant's name -- confirm against `mlx_lm.utils.load`.
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
cls.model, cls.tokenizer = load(HF_MODEL_PATH)
def test_generate(self):
# Simple test that generation runs
# NOTE(review): this view is a diff whose +/- markers were stripped. The next
# three lines appear to be the removed (pre-change) version, which loads the
# model inside the test itself; the multi-line `generate(...)` call after them
# appears to be the replacement, which reads `self.model` / `self.tokenizer`
# set by `setUpClass`. Confirm against the actual commit before relying on this.
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
model, tokenizer = load(HF_MODEL_PATH)
text = generate(model, tokenizer, "hello", max_tokens=5, verbose=False)
# Calls `generate` on the prompt "hello" with max_tokens=5 and verbose=False;
# the result is bound to `text` and not asserted on in the lines shown.
text = generate(
self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
)
def test_generate_with_processor(self):
# Simple test that generation runs
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
model, tokenizer = load(HF_MODEL_PATH)
init_toks = tokenizer.encode("hello")
init_toks = self.tokenizer.encode("hello")
all_toks = None
@@ -28,8 +29,8 @@ class TestGenerate(unittest.TestCase):
return logits
generate(
model,
tokenizer,
self.model,
self.tokenizer,
"hello",
max_tokens=5,
verbose=False,