# Copyright © 2024 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
from mlx_lm import utils

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestUtils(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name
        if not os.path.isdir(cls.test_dir):
            os.mkdir(cls.test_dir_fid.name)

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def test_load(self):
        # Eager and lazy loading should yield identical weights once evaluated.
        model, _ = utils.load(HF_MODEL_PATH)

        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)
        mx.eval(model_lazy.parameters())

        p1 = model.layers[0].mlp.up_proj.weight
        p2 = model_lazy.layers[0].mlp.up_proj.weight
        self.assertTrue(mx.allclose(p1, p2))

    def test_make_shards(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=2048,
            num_hidden_layers=32,
            intermediate_size=4096,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=30_000,
        )
        model = llama.Model(args)
        weights = tree_flatten(model.parameters())
        gb = sum(p.nbytes for _, p in weights) // 2**30
        # With a 1 GB shard size, the number of shards should roughly match
        # the model size in gigabytes.
        shards = utils.make_shards(dict(weights), 1)
        self.assertTrue(gb <= len(shards) <= gb + 1)

    def test_quantize(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        # Quantize with group size 64 and 4 bits; quantized layers should gain
        # scales and biases, and the config should record the settings.
        weights, config = utils.quantize_model(model, {}, 64, 4)
        self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights)
        self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights)
        self.assertEqual(config["quantization"]["group_size"], 64)
        self.assertEqual(config["quantization"]["bits"], 4)

    def test_convert(self):
        mlx_path = os.path.join(self.test_dir, "mlx_model")

        # Quantized conversion should replace linear layers with QuantizedLinear.
        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=True)
        model, _ = utils.load(mlx_path)
        self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear))
        self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))

        # Check that the converted model weights have the requested dtype.
        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
        model, _ = utils.load(mlx_path)
        self.assertEqual(model.layers[0].mlp.up_proj.weight.dtype, mx.bfloat16)
        self.assertEqual(model.layers[-1].mlp.up_proj.weight.dtype, mx.bfloat16)


if __name__ == "__main__":
    unittest.main()