This commit is contained in:
Awni Hannun
2024-09-28 08:32:09 -07:00
parent 824f7fda58
commit c8216caa61

View File

@@ -7,18 +7,19 @@ from mlx_lm.utils import generate, load
class TestGenerate(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
        cls.model, cls.tokenizer = load(HF_MODEL_PATH)
    def test_generate(self):
        # Simple test that generation runs
        text = generate(
            self.model, self.tokenizer, "hello", max_tokens=5, verbose=False
        )
    def test_generate_with_processor(self):
        init_toks = self.tokenizer.encode("hello")

        all_toks = None
@@ -28,8 +29,8 @@ class TestGenerate(unittest.TestCase):
            return logits

        generate(
            self.model,
            self.tokenizer,
            "hello",
            max_tokens=5,
            verbose=False,