mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
Kv cache (#643)
* in place kv_cache * fix * fix kv cache size * partially fix kv cache dtype * step kv cache * multiple of step size * more teests + kv cache * more kv cache * udpate all models to use kv cache
This commit is contained in:
@@ -4,6 +4,7 @@ import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models.base import KVCache
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
@@ -17,13 +18,18 @@ class TestModels(unittest.TestCase):
|
||||
model.update(tree_map(lambda p: p.astype(t), model.parameters()))
|
||||
|
||||
inputs = mx.array([[0, 1]])
|
||||
outputs, cache = model(inputs)
|
||||
outputs = model(inputs)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
outputs, cache = model(
|
||||
mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
|
||||
kv_heads = (
|
||||
[model.n_kv_heads] * len(model.layers)
|
||||
if isinstance(model.n_kv_heads, int)
|
||||
else model.n_kv_heads
|
||||
)
|
||||
cache = [KVCache(model.head_dim, n) for n in kv_heads]
|
||||
|
||||
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
|
||||
self.assertEqual(outputs.shape, (1, 1, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
@@ -53,6 +59,15 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phixtral(self):
|
||||
from mlx_lm.models import phixtral
|
||||
|
||||
args = phixtral.ModelArgs(
|
||||
"phixtral", num_vocab=1000, num_layers=4, model_dim=1024
|
||||
)
|
||||
model = phixtral.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)
|
||||
|
||||
def test_phi3(self):
|
||||
from mlx_lm.models import phi3
|
||||
|
||||
@@ -264,6 +279,100 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_dbrx(self):
|
||||
from mlx_lm.models import dbrx
|
||||
|
||||
args = dbrx.ModelArgs(
|
||||
model_type="dbrx",
|
||||
d_model=1024,
|
||||
ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
|
||||
attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
|
||||
n_layers=4,
|
||||
n_heads=4,
|
||||
vocab_size=10_000,
|
||||
)
|
||||
model = dbrx.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)
|
||||
|
||||
def test_minicpm(self):
|
||||
from mlx_lm.models import minicpm
|
||||
|
||||
args = minicpm.ModelArgs(
|
||||
model_type="minicpm",
|
||||
hidden_size=1024,
|
||||
dim_model_base=1024,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=2048,
|
||||
num_attention_heads=4,
|
||||
rms_norm_eps=1e-4,
|
||||
vocab_size=10000,
|
||||
num_key_value_heads=2,
|
||||
scale_depth=1.0,
|
||||
scale_emb=1.0,
|
||||
)
|
||||
model = minicpm.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_openelm(self):
|
||||
from mlx_lm.models import openelm
|
||||
|
||||
args = openelm.ModelArgs(
|
||||
model_type="openelm",
|
||||
ffn_dim_divisor=256,
|
||||
ffn_multipliers=[
|
||||
0.5,
|
||||
0.73,
|
||||
0.97,
|
||||
1.2,
|
||||
1.43,
|
||||
1.67,
|
||||
1.9,
|
||||
2.13,
|
||||
2.37,
|
||||
2.6,
|
||||
2.83,
|
||||
3.07,
|
||||
3.3,
|
||||
3.53,
|
||||
3.77,
|
||||
4.0,
|
||||
],
|
||||
head_dim=64,
|
||||
model_dim=1280,
|
||||
normalize_qk_projections=True,
|
||||
num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
|
||||
num_query_heads=[
|
||||
12,
|
||||
12,
|
||||
12,
|
||||
12,
|
||||
12,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
16,
|
||||
20,
|
||||
20,
|
||||
20,
|
||||
20,
|
||||
],
|
||||
num_transformer_layers=16,
|
||||
vocab_size=32000,
|
||||
)
|
||||
|
||||
model = openelm.Model(args)
|
||||
self.model_test_runner(
|
||||
model,
|
||||
args.model_type,
|
||||
args.vocab_size,
|
||||
len(args.ffn_multipliers),
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user