* in place kv_cache

* fix

* fix kv cache size

* partially fix kv cache dtype

* step kv cache

* multiple of step size

* more teests + kv cache

* more kv cache

* udpate all models to use kv cache
This commit is contained in:
Awni Hannun
2024-05-08 08:18:13 -07:00
committed by GitHub
parent bfbc0e434a
commit ee60e2a9d5
22 changed files with 534 additions and 298 deletions

View File

@@ -4,6 +4,7 @@ import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache
class TestModels(unittest.TestCase):
@@ -17,13 +18,18 @@ class TestModels(unittest.TestCase):
model.update(tree_map(lambda p: p.astype(t), model.parameters()))
inputs = mx.array([[0, 1]])
outputs, cache = model(inputs)
outputs = model(inputs)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
outputs, cache = model(
mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache
kv_heads = (
[model.n_kv_heads] * len(model.layers)
if isinstance(model.n_kv_heads, int)
else model.n_kv_heads
)
cache = [KVCache(model.head_dim, n) for n in kv_heads]
outputs = model(mx.argmax(outputs[0, -1:, :], keepdims=True), cache=cache)
self.assertEqual(outputs.shape, (1, 1, vocab_size))
self.assertEqual(outputs.dtype, t)
@@ -53,6 +59,15 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phixtral(self):
from mlx_lm.models import phixtral
args = phixtral.ModelArgs(
"phixtral", num_vocab=1000, num_layers=4, model_dim=1024
)
model = phixtral.Model(args)
self.model_test_runner(model, args.model_type, args.num_vocab, args.num_layers)
def test_phi3(self):
from mlx_lm.models import phi3
@@ -264,6 +279,100 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_dbrx(self):
from mlx_lm.models import dbrx
args = dbrx.ModelArgs(
model_type="dbrx",
d_model=1024,
ffn_config={"ffn_hidden_size": 2048, "moe_num_experts": 4, "moe_top_k": 2},
attn_config={"kv_n_heads": 2, "clip_qkv": True, "rope_theta": 10000},
n_layers=4,
n_heads=4,
vocab_size=10_000,
)
model = dbrx.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layers)
def test_minicpm(self):
from mlx_lm.models import minicpm
args = minicpm.ModelArgs(
model_type="minicpm",
hidden_size=1024,
dim_model_base=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-4,
vocab_size=10000,
num_key_value_heads=2,
scale_depth=1.0,
scale_emb=1.0,
)
model = minicpm.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_openelm(self):
from mlx_lm.models import openelm
args = openelm.ModelArgs(
model_type="openelm",
ffn_dim_divisor=256,
ffn_multipliers=[
0.5,
0.73,
0.97,
1.2,
1.43,
1.67,
1.9,
2.13,
2.37,
2.6,
2.83,
3.07,
3.3,
3.53,
3.77,
4.0,
],
head_dim=64,
model_dim=1280,
normalize_qk_projections=True,
num_kv_heads=[3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 5, 5, 5, 5],
num_query_heads=[
12,
12,
12,
12,
12,
16,
16,
16,
16,
16,
16,
16,
20,
20,
20,
20,
],
num_transformer_layers=16,
vocab_size=32000,
)
model = openelm.Model(args)
self.model_test_runner(
model,
args.model_type,
args.vocab_size,
len(args.ffn_multipliers),
)
if __name__ == "__main__":
unittest.main()