# Copyright © 2024 Apple Inc.

import unittest

import mlx.core as mx
from mlx.utils import tree_map

from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache


class TestModels(unittest.TestCase):
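    # KVCache appends new keys/values to a buffer that is (as of this
    # writing) preallocated in ``cache.step``-sized increments, with
    # ``cache.offset`` tracking how many tokens are cached so far.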
    def test_kv_cache(self):
        cache = KVCache()

        k = mx.ones((1, 4, 1, 32), mx.float16)
        v = mx.ones((1, 4, 1, 32), mx.float16)

        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 1)

        k = mx.ones((1, 4, cache.step, 32), mx.float16)
        v = mx.ones((1, 4, cache.step, 32), mx.float16)
        k_up, v_up = cache.update_and_fetch(k, v)

        expected = mx.ones((1, 4, cache.step + 1, 32), mx.float16)
        self.assertTrue(mx.array_equal(k_up, expected))
        self.assertTrue(mx.array_equal(v_up, expected))
        self.assertEqual(cache.offset, cache.step + 1)
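
    # RotatingKVCache bounds the cache at ``max_size`` entries: once full,
    # new tokens overwrite the oldest slots in place, optionally keeping the
    # first ``keep`` positions (e.g. an initial prompt) untouched.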
    def test_rotating_kv_cache(self):
        b, h, d = 1, 2, 32
        cache = RotatingKVCache(max_size=8, step=4)

        k = mx.random.uniform(shape=(b, h, 2, d))
        v = mx.random.uniform(shape=(b, h, 2, d))

        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))
        self.assertEqual(cache.offset, 2)

        k = mx.random.uniform(shape=(b, h, 5, d))
        v = mx.random.uniform(shape=(b, h, 5, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up[..., 2:, :], k))
        self.assertTrue(mx.array_equal(v_up[..., 2:, :], v))

        k = mx.random.uniform(shape=(b, h, 4, d))
        v = mx.random.uniform(shape=(b, h, 4, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up[..., -4:, :], k))
        self.assertTrue(mx.array_equal(v_up[..., -4:, :], v))
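
        # Once the buffer is full, single-token updates write in place,
        # cycling through the eight slots.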
        idx = 0
        for _ in range(10):
            k = mx.random.uniform(shape=(b, h, 1, d))
            v = mx.random.uniform(shape=(b, h, 1, d))
            k_up, v_up = cache.update_and_fetch(k, v)
            self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
            self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
            idx += 1
            idx %= 8

        # Try with nonzero keep
        cache = RotatingKVCache(max_size=8, step=4, keep=2)

        # Check a large update
        k = mx.random.uniform(shape=(b, h, 20, d))
        v = mx.random.uniform(shape=(b, h, 20, d))
        k_up, v_up = cache.update_and_fetch(k, v)
        self.assertTrue(mx.array_equal(k_up, k))
        self.assertTrue(mx.array_equal(v_up, v))

        # A bunch of small updates
        self.assertEqual(cache.offset, 20)
        idx = 2
        for i in range(10):
            k = mx.random.uniform(shape=(b, h, 1, d))
            v = mx.random.uniform(shape=(b, h, 1, d))
            k_up, v_up = cache.update_and_fetch(k, v)
            self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k))
            self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v))
            self.assertEqual(cache.offset, 21 + i)
            idx += 1
            if idx >= 8:
                idx = 2

    def test_rotating_kv_cache_chat_mode(self):
        # Test that the rotating kv cache can handle
        # alternating prompt/prefill with generation
        d = 4
        h = 2
        cache = RotatingKVCache(max_size=18, step=4)

        x = mx.random.uniform(shape=(1, h, 8, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 8)
        self.assertEqual(cache.offset, 8)

        x = mx.random.uniform(shape=(1, h, 1, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 9)
        self.assertEqual(cache.offset, 9)
        self.assertTrue(mx.allclose(x, k[..., 8:9, :]))

        x = mx.random.uniform(shape=(1, h, 2, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 11)
        self.assertEqual(cache.offset, 11)
        self.assertTrue(mx.allclose(x, k[..., 9:11, :]))

        x = mx.random.uniform(shape=(1, h, 3, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(k.shape[2], 14)
        self.assertEqual(cache.offset, 14)
        self.assertTrue(mx.allclose(x, k[..., 11:14, :]))

        x = mx.random.uniform(shape=(1, h, 6, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(cache.offset, 20)
        self.assertTrue(mx.allclose(x, k[..., -6:, :]))

        x = mx.random.uniform(shape=(1, h, 2, d))
        k, v = cache.update_and_fetch(x, x)
        self.assertEqual(cache.offset, 22)
        self.assertTrue(mx.allclose(x, k[..., -2:, :]))
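
    # Shared runner: checks the layer count and model_type, then runs a
    # forward pass in float32 and float16, both with and without a prompt
    # cache.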
    def model_test_runner(self, model, model_type, vocab_size, num_layers):
        self.assertEqual(len(model.layers), num_layers)
        self.assertEqual(model.model_type, model_type)

        for t in [mx.float32, mx.float16]:
            model.update(tree_map(lambda p: p.astype(t), model.parameters()))

            inputs = mx.array([[0, 1]])
            outputs = model(inputs)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)

            cache = make_prompt_cache(model)
            outputs = model(inputs, cache)
            self.assertEqual(outputs.shape, (1, 2, vocab_size))
            self.assertEqual(outputs.dtype, t)
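
            # Generation step: feed the argmax token back in with the same
            # cache and expect a single-position output.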
            outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
            self.assertEqual(outputs.shape, (1, 1, vocab_size))
            self.assertEqual(outputs.dtype, t)
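
    # Each test below instantiates one architecture from its ModelArgs and
    # runs it through model_test_runner.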
    def test_llama(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi2(self):
        from mlx_lm.models import phi

        args = phi.ModelArgs()
        model = phi.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phixtral(self):
        from mlx_lm.models import phixtral

        args = phixtral.ModelArgs(
            "phixtral", num_vocab=1000, num_layers=4, model_dim=1024
        )
        model = phixtral.Model(args)
        self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)

    def test_phi3(self):
        from mlx_lm.models import phi3

        args = phi3.ModelArgs(
            model_type="phi3",
            hidden_size=3072,
            num_hidden_layers=32,
            intermediate_size=8192,
            num_attention_heads=32,
            rms_norm_eps=1e-5,
            vocab_size=32064,
        )
        model = phi3.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma(self):
        from mlx_lm.models import gemma

        args = gemma.ModelArgs(
            model_type="gemma",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            head_dim=128,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_key_value_heads=4,
        )
        model = gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mixtral(self):
        from mlx_lm.models import mixtral

        # Make a baby mixtral, because it will actually do the
        # eval
        args = mixtral.ModelArgs(
            model_type="mixtral",
            vocab_size=100,
            hidden_size=32,
            intermediate_size=128,
            num_hidden_layers=2,
            num_attention_heads=4,
            num_experts_per_tok=2,
            num_key_value_heads=2,
            num_local_experts=4,
        )
        model = mixtral.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    @unittest.skip("requires ai2-olmo")
    def test_olmo(self):
        from mlx_lm.models import olmo

        args = olmo.ModelArgs(
            model_type="olmo",
            d_model=1024,
            n_layers=4,
            mlp_hidden_size=2048,
            n_heads=2,
            vocab_size=10_000,
            embedding_size=10_000,
        )
        model = olmo.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            args.n_layers,
        )

    def test_qwen2_moe(self):
        from mlx_lm.models import qwen2_moe

        args = qwen2_moe.ModelArgs(
            model_type="qwen2_moe",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            num_experts_per_tok=4,
            num_experts=16,
            moe_intermediate_size=1024,
            shared_expert_intermediate_size=2048,
        )
        model = qwen2_moe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen2(self):
        from mlx_lm.models import qwen2

        args = qwen2.ModelArgs(
            model_type="qwen2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = qwen2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_qwen(self):
        from mlx_lm.models import qwen

        args = qwen.ModelArgs(
            model_type="qwen",
        )
        model = qwen.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_plamo(self):
        from mlx_lm.models import plamo

        args = plamo.ModelArgs(
            model_type="plamo",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=8,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
        )
        model = plamo.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_stablelm(self):
        from mlx_lm.models import stablelm

        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10_000,
            hidden_size=1024,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=1.0,
            intermediate_size=2048,
            layer_norm_eps=1e-2,
            rope_theta=10_000,
            use_qkv_bias=False,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

        # StableLM 2
        args = stablelm.ModelArgs(
            model_type="stablelm",
            vocab_size=10000,
            hidden_size=512,
            num_attention_heads=8,
            num_hidden_layers=4,
            num_key_value_heads=2,
            partial_rotary_factor=0.25,
            intermediate_size=1024,
            layer_norm_eps=1e-5,
            rope_theta=10000,
            use_qkv_bias=True,
        )
        model = stablelm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_starcoder2(self):
        from mlx_lm.models import starcoder2

        args = starcoder2.ModelArgs(
            model_type="starcoder2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            num_key_value_heads=4,
        )
        model = starcoder2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_cohere(self):
        from mlx_lm.models import cohere

        args = cohere.ModelArgs(
            model_type="cohere",
        )
        model = cohere.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_dbrx(self):
        from mlx_lm.models import dbrx

        args = dbrx.ModelArgs(
            model_type="dbrx",
            d_model=1024,
            ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
            attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
            n_layers=4,
            n_heads=4,
            vocab_size=10_000,
        )
        model = dbrx.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)

    def test_minicpm(self):
        from mlx_lm.models import minicpm

        args = minicpm.ModelArgs(
            model_type="minicpm",
            hidden_size=1024,
            dim_model_base=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-4,
            vocab_size=10000,
            num_key_value_heads=2,
            scale_depth=1.0,
            scale_emb=1.0,
        )
        model = minicpm.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_mamba(self):
        from mlx_lm.models import mamba

        args = mamba.ModelArgs(
            model_type="mamba",
            vocab_size=10000,
            use_bias=False,
            use_conv_bias=True,
            conv_kernel=4,
            hidden_size=768,
            num_hidden_layers=24,
            state_size=16,
            intermediate_size=1536,
            time_step_rank=48,
        )
        model = mamba.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt2(self):
        from mlx_lm.models import gpt2

        args = gpt2.ModelArgs(
            model_type="gpt2",
            n_ctx=1024,
            n_embd=768,
            n_head=12,
            n_layer=12,
            n_positions=1024,
            layer_norm_epsilon=1e-5,
            vocab_size=50256,
        )
        model = gpt2.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_gpt_neox(self):
        from mlx_lm.models import gpt_neox

        args = gpt_neox.ModelArgs(
            model_type="gpt_neox",
            max_position_embeddings=2048,
            hidden_size=6144,
            num_attention_heads=64,
            num_hidden_layers=44,
            layer_norm_eps=1e-5,
            vocab_size=50432,
            rotary_emb_base=10_000,
            rotary_pct=0.25,
        )
        model = gpt_neox.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_openelm(self):
        from mlx_lm.models import openelm

        args = openelm.ModelArgs(
            model_type="openelm",
            ffn_dim_divisor=256,
            ffn_multipliers=[
                0.5, 0.73, 0.97, 1.2, 1.43, 1.67, 1.9, 2.13,
                2.37, 2.6, 2.83, 3.07, 3.3, 3.53, 3.77, 4.0,
            ],
            head_dim=64,
            model_dim=1280,
            normalize_qk_projections=True,
            num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
            num_query_heads=[
                12, 12, 12, 12, 12, 16, 16, 16, 16, 16, 16, 16, 20, 20, 20, 20,
            ],
            num_transformer_layers=16,
            vocab_size=32000,
        )

        model = openelm.Model(args)
        self.model_test_runner(
            model,
            args.model_type,
            args.vocab_size,
            len(args.ffn_multipliers),
        )

    def test_internlm2(self):
        from mlx_lm.models import internlm2

        args = internlm2.ModelArgs(
            model_type="internlm2",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10000,
        )
        model = internlm2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_llama3_1(self):
        from mlx_lm.models import llama

        args = llama.ModelArgs(
            model_type="llama",
            hidden_size=1024,
            num_hidden_layers=4,
            intermediate_size=2048,
            num_attention_heads=4,
            rms_norm_eps=1e-5,
            vocab_size=10_000,
            max_position_embeddings=128,
            mlp_bias=False,
            num_key_value_heads=2,
            rope_scaling={
                "factor": 8.0,
                "low_freq_factor": 1.0,
                "high_freq_factor": 4.0,
                "original_max_position_embeddings": 8192,
                "rope_type": "llama3",
            },
        )
        model = llama.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek(self):
        from mlx_lm.models import deepseek

        args = deepseek.ModelArgs(
            model_type="deepseek",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=8,
            num_key_value_heads=4,
        )
        model = deepseek.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_deepseek_v2(self):
        from mlx_lm.models import deepseek_v2

        args = deepseek_v2.ModelArgs(
            model_type="deepseek_v2",
            vocab_size=1024,
            hidden_size=128,
            intermediate_size=256,
            moe_intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            kv_lora_rank=4,
            q_lora_rank=4,
            qk_rope_head_dim=32,
            v_head_dim=16,
            qk_nope_head_dim=32,
            rope_scaling={
                "beta_fast": 32,
                "beta_slow": 1,
                "factor": 40,
                "mscale": 1.0,
                "mscale_all_dim": 1.0,
                "original_max_position_embeddings": 4096,
                "type": "yarn",
            },
        )
        model = deepseek_v2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gemma2(self):
        from mlx_lm.models import gemma2

        args = gemma2.ModelArgs(
            model_type="gemma2",
            hidden_size=128,
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=2,
            head_dim=32,
            rms_norm_eps=1e-4,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = gemma2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_gpt_bigcode(self):
        from mlx_lm.models import gpt_bigcode

        args = gpt_bigcode.ModelArgs(
            model_type="gpt_bigcode",
            n_embd=128,
            n_layer=128,
            n_inner=256,
            n_head=4,
            n_positions=1000,
            layer_norm_epsilon=1e-5,
            vocab_size=1024,
        )
        model = gpt_bigcode.Model(args)
        self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)

    def test_nemotron(self):
        from mlx_lm.models import nemotron

        args = nemotron.ModelArgs(
            model_type="nemotron",
            hidden_size=128,
            hidden_act="gelu",
            num_hidden_layers=4,
            intermediate_size=256,
            num_attention_heads=4,
            norm_eps=1e-5,
            vocab_size=1024,
            num_key_value_heads=2,
        )
        model = nemotron.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phi3small(self):
        from mlx_lm.models import phi3small

        args = phi3small.ModelArgs(
            model_type="phi3small",
            hidden_size=128,
            dense_attention_every_n_layers=2,
            ff_intermediate_size=256,
            gegelu_limit=1.0,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=2,
            layer_norm_epsilon=1e-4,
            vocab_size=1000,
        )
        model = phi3small.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_phimoe(self):
        from mlx_lm.models import phimoe

        args = phimoe.ModelArgs(
            model_type="phimoe",
            vocab_size=320,
            hidden_size=128,
            intermediate_size=256,
            num_hidden_layers=4,
            num_attention_heads=4,
            num_key_value_heads=4,
            rope_scaling={
                "long_factor": [1.0] * 16,
                "long_mscale": 1.243163121016122,
                "original_max_position_embeddings": 4096,
                "short_factor": [1.0] * 16,
                "short_mscale": 1.243163121016122,
                "type": "longrope",
            },
        )
        model = phimoe.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_recurrent_gemma(self):
        from mlx_lm.models import recurrent_gemma

        args = recurrent_gemma.ModelArgs(
            model_type="recurrent_gemma",
            hidden_size=128,
            attention_bias=False,
            conv1d_width=3,
            intermediate_size=256,
            logits_soft_cap=1.0,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            attention_window_size=1024,
            vocab_size=1000,
            block_types=["recurrent", "recurrent", "attention"],
        )
        model = recurrent_gemma.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_hunyuan(self):
        from mlx_lm.models import hunyuan

        args = hunyuan.ModelArgs(
            model_type="hunyuan",
            hidden_size=128,
            attention_bias=False,
            intermediate_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            vocab_size=1000,
            moe_topk=2,
            num_experts=2,
            num_shared_expert=1,
            use_mixed_mlp_moe=True,
            use_qk_norm=True,
            rope_scaling={
                "alpha": 1000.0,
                "factor": 1.0,
                "type": "dynamic",
            },
            use_cla=True,
            cla_share_factor=2,
        )
        model = hunyuan.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )

    def test_olmo2(self):
        from mlx_lm.models import olmo2

        args = olmo2.ModelArgs(
            model_type="olmo2",
            hidden_size=128,
            attention_bias=False,
            intermediate_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
            num_key_value_heads=2,
            rms_norm_eps=1e-4,
            rope_theta=1000,
            vocab_size=1000,
        )
        model = olmo2.Model(args)
        self.model_test_runner(
            model, args.model_type, args.vocab_size, args.num_hidden_layers
        )


if __name__ == "__main__":
    unittest.main()