mlx-examples/llms/tests/test_utils.py

# Copyright © 2024 Apple Inc.

import os
import tempfile
import unittest

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten

from mlx_lm import utils

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestUtils(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name
        if not os.path.isdir(cls.test_dir):
            os.mkdir(cls.test_dir_fid.name)

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()
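
    # Loading with lazy=True should produce the same parameters as eager
    # loading once they are evaluated.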
    def test_load(self):
        model, _ = utils.load(HF_MODEL_PATH)

        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)
        mx.eval(model_lazy.parameters())

        p1 = model.layers[0].mlp.up_proj.weight
        p2 = model_lazy.layers[0].mlp.up_proj.weight
        self.assertTrue(mx.allclose(p1, p2))
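
    # make_shards should split the flattened weights into shards of at most
    # ~1 GiB each, so the shard count tracks the total model size.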
    def test_make_shards(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=2048,
            num_hidden_layers=32,
            intermediate_size=4096,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=30_000,
        )
        model = llama.Model(args)
        weights = tree_flatten(model.parameters())
        gb = sum(p.nbytes for _, p in weights) // 2**30
        shards = utils.make_shards(dict(weights), 1)
        self.assertTrue(gb <= len(shards) <= gb + 1)
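
    # Quantizing should add scales and biases for the quantized weights and
    # record the group size and bit width in the config.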
    def test_quantize(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        weights, config = utils.quantize_model(model, {}, 64, 4)
        self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights)
        self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights)
        self.assertEqual(config["quantization"]["group_size"], 64)
        self.assertEqual(config["quantization"]["bits"], 4)
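
    # convert should write a loadable MLX model: quantize=True should yield
    # QuantizedLinear layers, and dtype should control the weight dtype.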
    def test_convert(self):
        mlx_path = os.path.join(self.test_dir, "mlx_model")

        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=True)
        model, _ = utils.load(mlx_path)
        self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear))
        self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))

        # Check that the converted weights have the requested dtype
        utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
        model, _ = utils.load(mlx_path)
        self.assertEqual(model.layers[0].mlp.up_proj.weight.dtype, mx.bfloat16)
        self.assertEqual(model.layers[-1].mlp.up_proj.weight.dtype, mx.bfloat16)


if __name__ == "__main__":
    unittest.main()