# Copyright © 2024 Apple Inc.
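# Tests for the model implementations in mlx_lm.models: most tests build a
# small configuration and run it through TestModels.model_test_runner, while
# the tie_word_embeddings tests exercise Model.sanitize directly.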

import unittest

import mlx.core as mx
from mlx.utils import tree_map


class TestModels(unittest.TestCase):
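    # Shared checks: verify layer count and model_type, then run a two-token
    # prefill and a single-token step with the returned cache, asserting the
    # output shape and dtype in both float32 and float16.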
    def model_test_runner(self, model, model_type, vocab_size, num_layers):
        self.assertEqual(len(model.layers), num_layers)
        self.assertEqual(model.model_type, model_type)

        for t in [mx.float32, mx.float16]:
            model.update(tree_map(lambda p: p.astype(t), model.parameters()))

            inputs = mx.array([[0, 1]])
            outputs, cache = model(inputs)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)
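
            # Feed the argmax of the final position back in together with the
            # cache to exercise single-token decoding.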
            outputs, cache = model(
                mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
            )
            self.assertEqual(outputs.shape, (1, 1, vocab_size))
            self.assertEqual(outputs.dtype, t)

    def test_llama(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi2(self):
        from mlx_lm.models import phi

        args = phi.ModelArgs()
        model = phi.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma(self):
        from mlx_lm.models import gemma

        args = gemma.ModelArgs(
            model_type="gemma",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            head_dim=128,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_key_value_heads=4,
        )
        model = gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mixtral(self):
        from mlx_lm.models import mixtral

        # Make a baby mixtral, because it will actually do the eval
        args = mixtral.ModelArgs(
            model_type="mixtral",
            vocab_size=100,
            hidden_size=32,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_experts_per_tok=2,
            num_key_value_heads=2,
            num_local_experts=4,
        )
        model = mixtral.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    @unittest.skip("requires ai2-olmo")
    def test_olmo(self):
        from mlx_lm.models import olmo

        args = olmo.ModelArgs(
            model_type="olmo",
            d_model=1024,
            n_layers=4,
            mlp_hidden_size=2048,
            n_heads=2,
            vocab_size=10_000,
            embedding_size=10_000,
        )
        model = olmo.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            args.n_layers,
        )

    def test_qwen2(self):
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = qwen2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )
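
    # The two tie_word_embeddings tests below check Model.sanitize(): when the
    # checkpoint omits lm_head.weight it should be filled in from
    # model.embed_tokens.weight, and an existing lm_head.weight should be kept.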
    def test_qwen2_tie_word_embeddings_without_lm_head_weight(self):
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            tie_word_embeddings=True,
        )
        model = qwen2.Model(args)
        weights = {"model.embed_tokens.weight": "some_value"}
        sanitized_weights = model.sanitize(weights)
        self.assertIn("lm_head.weight", sanitized_weights)
        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")

    def test_qwen2_tie_word_embeddings_with_lm_head_weight(self):
        from mlx_lm.models import qwen2

        weights = {
            "model.embed_tokens.weight": "some_value",
            "lm_head.weight": "existing_value",
        }
        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            tie_word_embeddings=True,
        )
        model = qwen2.Model(args)
        sanitized_weights = model.sanitize(weights)
        self.assertIn("lm_head.weight", sanitized_weights)
        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")

    def test_qwen(self):
        from mlx_lm.models import qwen

        args = qwen.ModelArgs(
            model_type="qwen",
        )
        model = qwen.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo(self):
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            model_type="plamo",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        from mlx_lm.models import stablelm

        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10_000,
            hidden_size=1024,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=1.0,
            intermediate_size=2048,
            layer_norm_eps=1e-2,
            rope_theta=10_000,
            use_qkv_bias=False,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_starcoder2(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
        )
        model = starcoder2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )
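
    # As with the qwen2 tests above, these check that sanitize() ties
    # lm_head.weight to model.embed_tokens.weight when it is missing and keeps
    # an existing lm_head.weight unchanged.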
    def test_starcoder2_tie_word_embeddings_without_lm_head_weight(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
            tie_word_embeddings=True,
        )
        model = starcoder2.Model(args)
        weights = {"model.embed_tokens.weight": "some_value"}
        sanitized_weights = model.sanitize(weights)
        self.assertIn("lm_head.weight", sanitized_weights)
        self.assertEqual(sanitized_weights["lm_head.weight"], "some_value")

    def test_starcoder2_tie_word_embeddings_with_lm_head_weight(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
            tie_word_embeddings=True,
        )
        model = starcoder2.Model(args)
        weights = {
            "model.embed_tokens.weight": "some_value",
            "lm_head.weight": "existing_value",
        }

        sanitized_weights = model.sanitize(weights)
        self.assertIn("lm_head.weight", sanitized_weights)
        self.assertEqual(sanitized_weights["lm_head.weight"], "existing_value")


if __name__ == "__main__":
    unittest.main()