mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-23 22:18:06 +08:00)
Enable unit testing in Circle and start some MLX LM tests (#545)
* add a few tests for mlx lm
* more tests / cleanup
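To run the new suites locally, here is a minimal sketch assuming mlx and mlx-lm are installed and the repository root is the working directory; the runner below is illustrative and not part of the commit:

# Discover and run the new test suites (illustrative, not part of the commit).
import unittest

suite = unittest.defaultTestLoader.discover("llms/tests", pattern="test_*.py")
unittest.TextTestRunner(verbosity=2).run(suite)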
llms/tests/test_models.py (new file, 199 lines)

@@ -0,0 +1,199 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_map


class TestModels(unittest.TestCase):

    def model_test_runner(self, model, model_type, vocab_size, num_layers):

        self.assertEqual(len(model.layers), num_layers)
        self.assertEqual(model.model_type, model_type)

        for t in [mx.float32, mx.float16]:
            model.update(tree_map(lambda p: p.astype(t), model.parameters()))

            # Prefill: a batch of one two-token prompt yields logits for
            # both positions plus a KV cache.
            inputs = mx.array([[0, 1]])
            outputs, cache = model(inputs)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)

            # Decode step: feed back the argmax token at the last position,
            # reusing the cache, and expect logits for a single position.
            next_token = mx.argmax(outputs[:, -1, :], axis=-1, keepdims=True)
            outputs, cache = model(next_token, cache=cache)
            self.assertEqual(outputs.shape, (1, 1, vocab_size))
            self.assertEqual(outputs.dtype, t)

    def test_llama(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi2(self):
        from mlx_lm.models import phi

        args = phi.ModelArgs()
        model = phi.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma(self):
        from mlx_lm.models import gemma

        args = gemma.ModelArgs(
            model_type="gemma",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            head_dim=128,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_key_value_heads=4,
        )
        model = gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mixtral(self):
        from mlx_lm.models import mixtral

        # Make a baby mixtral, because it will actually do the eval.
        args = mixtral.ModelArgs(
            model_type="mixtral",
            vocab_size=100,
            hidden_size=32,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_experts_per_tok=2,
            num_key_value_heads=2,
            num_local_experts=4,
        )
        model = mixtral.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    @unittest.skip("requires ai2-olmo")
    def test_olmo(self):
        from mlx_lm.models import olmo

        args = olmo.ModelArgs(
            model_type="olmo",
            d_model=1024,
            n_layers=4,
            mlp_hidden_size=2048,
            n_heads=2,
            vocab_size=10_000,
            embedding_size=10_000,
        )
        model = olmo.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            args.n_layers,
        )

    def test_qwen2(self):
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = qwen2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen(self):
        from mlx_lm.models import qwen

        args = qwen.ModelArgs(model_type="qwen")
        model = qwen.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo(self):
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            model_type="plamo",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        from mlx_lm.models import stablelm

        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10_000,
            hidden_size=1024,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=1.0,
            intermediate_size=2048,
            layer_norm_eps=1e-2,
            rope_theta=10_000,
            use_qkv_bias=False,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_starcoder2(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
        )
        model = starcoder2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )


if __name__ == "__main__":
    unittest.main()
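The runner above exercises the contract a generation loop depends on: a prefill call returns logits plus a KV cache, and later single-token calls reuse that cache. A minimal greedy-decoding sketch over the same interface; model and prompt_tokens are placeholders, not part of the commit:

import mlx.core as mx

def greedy_decode(model, prompt_tokens, max_new_tokens=8):
    # Prefill: run the full prompt once, keeping the KV cache.
    logits, cache = model(mx.array([prompt_tokens]))
    generated = []
    for _ in range(max_new_tokens):
        # Greedy pick at the last position; keepdims gives shape (1, 1).
        next_token = mx.argmax(logits[:, -1, :], axis=-1, keepdims=True)
        generated.append(next_token.item())
        # Decode step: one token in, logits for one position out.
        logits, cache = model(next_token, cache=cache)
    return generated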
llms/tests/test_utils.py (new file, 45 lines)

@@ -0,0 +1,45 @@
# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import utils

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestUtils(unittest.TestCase):

    def test_load(self):
        model, _ = utils.load(HF_MODEL_PATH)

        model_lazy, _ = utils.load(HF_MODEL_PATH, lazy=True)

        mx.eval(model_lazy.parameters())

        p1 = model.layers[0].mlp.up_proj.weight
        p2 = model_lazy.layers[0].mlp.up_proj.weight
        self.assertTrue(mx.allclose(p1, p2))

    def test_make_shards(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=2048,
            num_hidden_layers=32,
            intermediate_size=4096,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=30_000,
        )
        model = llama.Model(args)
        weights = tree_flatten(model.parameters())
        gb = sum(p.nbytes for _, p in weights) // 2**30
        # With a 1 GB shard cap, expect roughly one shard per gibibyte.
        shards = utils.make_shards(dict(weights), 1)
        self.assertTrue(gb <= len(shards) <= gb + 1)


if __name__ == "__main__":
    unittest.main()
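test_make_shards checks that splitting the weights at a 1 GB cap yields roughly one shard per gibibyte. For intuition, here is a sketch of what a size-capped packer could look like; this is an illustration only, not the actual mlx_lm.utils.make_shards:

def make_shards_sketch(weights, max_file_size_gb=5):
    # Greedily pack tensors into shards no larger than the cap
    # (a single tensor above the cap still gets its own shard).
    max_bytes = max_file_size_gb << 30
    shards, shard, shard_bytes = [], {}, 0
    for name, w in weights.items():
        if shard and shard_bytes + w.nbytes > max_bytes:
            shards.append(shard)
            shard, shard_bytes = {}, 0
        shard[name] = w
        shard_bytes += w.nbytes
    shards.append(shard)
    return shards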