mlx-examples/llms/tests/test_models.py
Awni Hannun 004eb4cc9d
Tencent HunYuan MOE model (#1100)
* hunyuan

* fix

* format str

* default trust remote code for tokenizer, allow system prompt to be configurable
2024-11-23 11:06:26 -08:00

798 lines
24 KiB
Python

# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
class TestModels(unittest.TestCase):
    """Smoke tests for the KV caches and the model architectures in mlx_lm."""

    def test_kv_cache(self):
        """Check returned arrays and offset of KVCache across two updates."""
        cache = KVCache()

        # A first, single-token update is handed back unchanged.
        k = mx.ones((1, 4, 1, 32), mx.float16)
        v = mx.ones((1, 4, 1, 32), mx.float16)

        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 1)

        # A second update of `cache.step` tokens must return everything
        # stored so far: step + 1 positions along the sequence axis.
        k = mx.ones((1, 4, cache.step, 32), mx.float16)
        v = mx.ones((1, 4, cache.step, 32), mx.float16)
        k_up, v_up = cache.update_and_fetch(k, v)

        expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16)
        self.assertTrue(mx.array_equal(k_up, expected))
        self.assertTrue(mx.array_equal(v_up, expected))
        self.assertEqual(cache.offset, cache.step + 1)

    def test_rotating_kv_cache(self):
        """Check where new entries land in RotatingKVCache, with and without keep."""
        b, h, d = 1, 2, 32
        cache = RotatingKVCache(max_size=8, step=4)

        k = mx.random.uniform(shape=(b, h, 2, d))
        v = mx.random.uniform(shape=(b, h, 2, d))

        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 2)

        # The next 5 tokens are expected right after the first 2.
        k = mx.random.uniform(shape=(b, h, 5, d))
        v = mx.random.uniform(shape=(b, h, 5, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
        self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))

        # A multi-token update is expected at the tail of the returned arrays.
        k = mx.random.uniform(shape=(b, h, 4, d))
        v = mx.random.uniform(shape=(b, h, 4, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
        self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))

        # Single-token updates: the expected write position starts at 0 and
        # wraps around modulo 8 (the max_size passed above).
        idx = 0
        for _ in range(10):
            k = mx.random.uniform(shape=(b, h, 1, d))
            v = mx.random.uniform(shape=(b, h, 1, d))
            k_up, v_up = cache.update_and_fetch(k, v)
            self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
            self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
            idx += 1
            idx %= 8

        # Try with nonzero keep
        cache = RotatingKVCache(max_size=8, step=4, keep=2)

        # Check a large update
        k = mx.random.uniform(shape=(b, h, 20, d))
        v = mx.random.uniform(shape=(b, h, 20, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 20)

        # A bunch of small updates: the expected write position starts at 2
        # (the keep value) and returns to 2 once it reaches 8.
        idx = 2
        for i in range(10):
            k = mx.random.uniform(shape=(b, h, 1, d))
            v = mx.random.uniform(shape=(b, h, 1, d))
            k_up, v_up = cache.update_and_fetch(k, v)
            self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
            self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
            self.assertEqual(cache.offset, 21 + i)
            idx += 1
            if idx >= 8:
                idx = 2

    def test_rotating_kv_cache_chat_mode(self):
        """Check RotatingKVCache with alternating multi- and single-token updates."""
        # Test that the rotating kv cache can handle
        # alternating prompt/prefill with generation
        d = 4
        h = 2
        cache = RotatingKVCache(max_size=18, step=4)

        x = mx.random.uniform(shape=(1, h, 8, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 8)
        self.assertEqual(cache.offset, 8)

        x = mx.random.uniform(shape=(1, h, 1, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 9)
        self.assertEqual(cache.offset, 9)
        self.assertTrue(mx.allclose(x, k[..., 8:9, :]))

        x = mx.random.uniform(shape=(1, h, 2, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 11)
        self.assertEqual(cache.offset, 11)
        self.assertTrue(mx.allclose(x, k[..., 9:11, :]))

        x = mx.random.uniform(shape=(1, h, 3, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 14)
        self.assertEqual(cache.offset, 14)
        self.assertTrue(mx.allclose(x, k[..., 11:14, :]))

        # From here the total token count (20, then 22) exceeds max_size=18,
        # so only the offset and the position of the newest tokens are checked.
        x = mx.random.uniform(shape=(1, h, 6, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(cache.offset, 20)
        self.assertTrue(mx.allclose(x, k[..., -6:, :]))

        x = mx.random.uniform(shape=(1, h, 2, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(cache.offset, 22)
        self.assertTrue(mx.allclose(x, k[..., -2:, :]))

    def model_test_runner(self, model, model_type, vocab_size, num_layers):
        """Run the shared structural checks on a freshly built model.

        Args:
            model: Model instance exposing ``layers``, ``model_type``,
                ``update`` and ``parameters``, and callable on token ids
                with an optional cache.
            model_type: Expected value of ``model.model_type``.
            vocab_size: Expected size of the last output dimension.
            num_layers: Expected ``len(model.layers)``.

        For float32 and float16 in turn, the parameters are cast to that
        dtype and the output shape and dtype are checked for a two-token
        input without a cache, the same input with a prompt cache, and a
        single follow-up token that reuses that cache.
        """
        self.assertEqual(len(model.layers), num_layers)
        self.assertEqual(model.model_type, model_type)

        for t in [mx.float32, mx.float16]:
            # subTest reports the dtype in any failure and lets the other
            # precision still run instead of aborting on the first one.
            with self.subTest(dtype=t):
                model.update(tree_map(lambda p: p.astype(t), model.parameters()))

                inputs = mx.array([[0, 1]])
                outputs = model(inputs)
                self.assertEqual(outputs.shape, (1, 2, vocab_size))
                self.assertEqual(outputs.dtype, t)

                cache = make_prompt_cache(model)
                outputs = model(inputs, cache)
                self.assertEqual(outputs.shape, (1, 2, vocab_size))
                self.assertEqual(outputs.dtype, t)

                # Feed the argmax of the last position back in as a
                # one-token continuation that reuses the populated cache.
                next_token = mx.argmax(outputs[0, -1:, :], keepdims=True)
                outputs = model(next_token, cache=cache)
                self.assertEqual(outputs.shape, (1, 1, vocab_size))
                self.assertEqual(outputs.dtype, t)

    def test_llama(self):
        """Run the shared model checks on a llama configuration."""
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi2(self):
        """Run the shared model checks on the default phi configuration."""
        from mlx_lm.models import phi

        args = phi.ModelArgs()
        model = phi.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phixtral(self):
        """Run the shared model checks on a phixtral configuration."""
        from mlx_lm.models import phixtral

        args = phixtral.ModelArgs(
            "phixtral", num_vocab=1000, num_layers=4, model_dim=1024
        )
        model = phixtral.Model(args)
        self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)

    def test_phi3(self):
        """Run the shared model checks on a phi3 configuration."""
        from mlx_lm.models import phi3

        args = phi3.ModelArgs(
            model_type="phi3",
            hidden_size=3072,
            num_hidden_layers=32,
            intermediate_size=8192,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=32064,
        )
        model = phi3.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma(self):
        """Run the shared model checks on a gemma configuration."""
        from mlx_lm.models import gemma

        args = gemma.ModelArgs(
            model_type="gemma",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            head_dim=128,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_key_value_heads=4,
        )
        model = gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mixtral(self):
        """Run the shared model checks on a deliberately tiny mixtral."""
        from mlx_lm.models import mixtral

        # Make a baby mixtral, because it will actually do the
        # eval
        args = mixtral.ModelArgs(
            model_type="mixtral",
            vocab_size=100,
            hidden_size=32,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_experts_per_tok=2,
            num_key_value_heads=2,
            num_local_experts=4,
        )
        model = mixtral.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    @unittest.skip("requires ai2-olmo")
    def test_olmo(self):
        """Run the shared model checks on an olmo configuration (skipped)."""
        from mlx_lm.models import olmo

        args = olmo.ModelArgs(
            model_type="olmo",
            d_model=1024,
            n_layers=4,
            mlp_hidden_size=2048,
            n_heads=2,
            vocab_size=10_000,
            embedding_size=10_000,
        )
        model = olmo.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            args.n_layers,
        )

    def test_qwen2_moe(self):
        """Run the shared model checks on a qwen2_moe configuration."""
        from mlx_lm.models import qwen2_moe

        args = qwen2_moe.ModelArgs(
            model_type="qwen2_moe",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_experts_per_tok=4,
            num_experts=16,
            moe_intermediate_size=1024,
            shared_expert_intermediate_size=2048,
        )
        model = qwen2_moe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen2(self):
        """Run the shared model checks on a qwen2 configuration."""
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = qwen2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen(self):
        """Run the shared model checks on qwen with only model_type set."""
        from mlx_lm.models import qwen

        args = qwen.ModelArgs(
            model_type="qwen",
        )
        model = qwen.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo(self):
        """Run the shared model checks on a plamo configuration."""
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            model_type="plamo",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        """Run the shared model checks on two stablelm configurations."""
        from mlx_lm.models import stablelm

        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10_000,
            hidden_size=1024,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=1.0,
            intermediate_size=2048,
            layer_norm_eps=1e-2,
            rope_theta=10_000,
            use_qkv_bias=False,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

        # StableLM 2
        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10000,
            hidden_size=512,
            num_attention_heads=8,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=0.25,
            intermediate_size=1024,
            layer_norm_eps=1e-5,
            rope_theta=10000,
            use_qkv_bias=True,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_starcoder2(self):
        """Run the shared model checks on a starcoder2 configuration."""
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
        )
        model = starcoder2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_cohere(self):
        """Run the shared model checks on cohere with only model_type set."""
        from mlx_lm.models import cohere

        args = cohere.ModelArgs(
            model_type="cohere",
        )
        model = cohere.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_dbrx(self):
        """Run the shared model checks on a dbrx configuration."""
        from mlx_lm.models import dbrx

        args = dbrx.ModelArgs(
            model_type="dbrx",
            d_model=1024,
            ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
            attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
            n_layers=4,
            n_heads=4,
            vocab_size=10_000,
        )
        model = dbrx.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)

    def test_minicpm(self):
        """Run the shared model checks on a minicpm configuration."""
        from mlx_lm.models import minicpm

        args = minicpm.ModelArgs(
            model_type="minicpm",
            hidden_size=1024,
            dim_model_base=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-4,
            vocab_size=10000,
            num_key_value_heads=2,
            scale_depth=1.0,
            scale_emb=1.0,
        )
        model = minicpm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mamba(self):
        """Run the shared model checks on a mamba configuration."""
        from mlx_lm.models import mamba

        args = mamba.ModelArgs(
            model_type="mamba",
            vocab_size=10000,
            use_bias=False,
            use_conv_bias=True,
            conv_kernel=4,
            hidden_size=768,
            num_hidden_layers=24,
            state_size=16,
            intermediate_size=1536,
            time_step_rank=48,
        )
        model = mamba.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt2(self):
        """Run the shared model checks on a gpt2 configuration."""
        from mlx_lm.models import gpt2

        args = gpt2.ModelArgs(
            model_type="gpt2",
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
            n_positions=1024,
            layer_norm_epsilon=1e-5,
            vocab_size=50256,
        )
        model = gpt2.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_gpt_neox(self):
        """Run the shared model checks on a gpt_neox configuration."""
        from mlx_lm.models import gpt_neox

        args = gpt_neox.ModelArgs(
            model_type="gpt_neox",
            max_position_embeddings=2048,
            hidden_size=6144,
            num_attention_heads=64,
            num_hidden_layers=44,
            layer_norm_eps=1e-5,
            vocab_size=50432,
            rotary_emb_base=10_000,
            rotary_pct=0.25,
        )
        model = gpt_neox.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_openelm(self):
        """Run the shared model checks on a 16-layer openelm configuration."""
        from mlx_lm.models import openelm

        args = openelm.ModelArgs(
            model_type="openelm",
            ffn_dim_divisor=256,
            ffn_multipliers=[
                0.5,
                0.73,
                0.97,
                1.2,
                1.43,
                1.67,
                1.9,
                2.13,
                2.37,
                2.6,
                2.83,
                3.07,
                3.3,
                3.53,
                3.77,
                4.0,
            ],
            head_dim=64,
            model_dim=1280,
            normalize_qk_projections=True,
            num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
            num_query_heads=[
                12,
                12,
                12,
                12,
                12,
                16,
                16,
                16,
                16,
                16,
                16,
                16,
                20,
                20,
                20,
                20,
            ],
            num_transformer_layers=16,
            vocab_size=32000,
        )
        model = openelm.Model(args)
        # The expected layer count is taken from the per-layer multiplier list.
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            len(args.ffn_multipliers),
        )

    def test_internlm2(self):
        """Run the shared model checks on an internlm2 configuration."""
        from mlx_lm.models import internlm2

        args = internlm2.ModelArgs(
            model_type="internlm2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10000,
        )
        model = internlm2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_llama3_1(self):
        """Run the shared model checks on llama with "llama3" rope scaling."""
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            max_position_embeddings=128,
            mlp_bias=False,
            num_key_value_heads=2,
            rope_scaling={
                "factor": 8.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek(self):
        """Run the shared model checks on a deepseek configuration."""
        from mlx_lm.models import deepseek

        args = deepseek.ModelArgs(
            model_type="deepseek",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=8,
            num_key_value_heads=4,
        )
        model = deepseek.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek_v2(self):
        """Run the shared model checks on deepseek_v2 with "yarn" rope scaling."""
        from mlx_lm.models import deepseek_v2

        args = deepseek_v2.ModelArgs(
            model_type="deepseek_v2",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            kv_lora_rank=4,
            q_lora_rank=4,
            qk_rope_head_dim=32,
            v_head_dim=16,
            qk_nope_head_dim=32,
            rope_scaling={
                "beta_fast": 32,
                "beta_slow": 1,
                "factor": 40,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 4096,
                "type": "yarn",
            },
        )
        model = deepseek_v2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma2(self):
        """Run the shared model checks on a gemma2 configuration."""
        from mlx_lm.models import gemma2

        args = gemma2.ModelArgs(
            model_type="gemma2",
            hidden_size=128,
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=2,
            head_dim=32,
            rms_norm_eps=1e-4,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = gemma2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt_bigcode(self):
        """Run the shared model checks on a gpt_bigcode configuration."""
        from mlx_lm.models import gpt_bigcode

        args = gpt_bigcode.ModelArgs(
            model_type="gpt_bigcode",
            n_embd=128,
            n_layer=128,
            n_inner=256,
            n_head=4,
            n_positions=1000,
            layer_norm_epsilon=1e-5,
            vocab_size=1024,
        )
        model = gpt_bigcode.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_nemotron(self):
        """Run the shared model checks on a nemotron configuration."""
        from mlx_lm.models import nemotron

        args = nemotron.ModelArgs(
            model_type="nemotron",
            hidden_size=128,
            hidden_act="gelu",
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=4,
            norm_eps=1e-5,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = nemotron.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi3small(self):
        """Run the shared model checks on a phi3small configuration."""
        from mlx_lm.models import phi3small

        args = phi3small.ModelArgs(
            model_type="phi3small",
            hidden_size=128,
            dense_attention_every_n_layers=2,
            ff_intermediate_size=256,
            gegelu_limit=1.0,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            layer_norm_epsilon=1e-4,
            vocab_size=1000,
        )
        model = phi3small.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phimoe(self):
        """Run the shared model checks on phimoe with "longrope" rope scaling."""
        from mlx_lm.models import phimoe

        args = phimoe.ModelArgs(
            model_type="phimoe",
            vocab_size=320,
            hidden_size=128,
            intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=4,
            rope_scaling={
                "long_factor": [1.0] * 16,
                "long_mscale": 1.243163121016122,
                "original_max_position_embeddings": 4096,
                "short_factor": [1.0] * 16,
                "short_mscale": 1.243163121016122,
                "type": "longrope",
            },
        )
        model = phimoe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_recurrent_gemma(self):
        """Run the shared model checks on a recurrent_gemma configuration."""
        from mlx_lm.models import recurrent_gemma

        args = recurrent_gemma.ModelArgs(
            model_type="recurrent_gemma",
            hidden_size=128,
            attention_bias=False,
            conv1d_width=3,
            intermediate_size=256,
            logits_soft_cap=1.0,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            attention_window_size=1024,
            vocab_size=1000,
            block_types=["recurrent", "recurrent", "attention"],
        )
        model = recurrent_gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_hunyuan(self):
        """Run the shared model checks on a hunyuan configuration."""
        from mlx_lm.models import hunyuan

        args = hunyuan.ModelArgs(
            model_type="hunyuan",
            hidden_size=128,
            attention_bias=False,
            intermediate_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            vocab_size=1000,
            moe_topk=2,
            num_experts=2,
            num_shared_expert=1,
            use_mixed_mlp_moe=True,
            use_qk_norm=True,
            rope_scaling={
                "alpha": 1000.0,
                "factor": 1.0,
                "type": "dynamic",
            },
            use_cla=True,
            cla_share_factor=2,
        )
        model = hunyuan.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )
# Run the tests only when this file is executed directly, not on import.
if __name__ == "__main__":
    unittest.main()