mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
More cache improvements (#1015)
* fix rotating kv cache for chat use case * reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat * nit in chat * fix tests * fix tests * fix tests * docs * chat command * comments + docs * Define meta_state on all Cache implementations * fixes + trim_prompt_cache api * fix default model --------- Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
This commit is contained in:
@@ -1,17 +1,15 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import unittest
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_map
|
||||
from mlx_lm.models.base import KVCache, RotatingKVCache
|
||||
from mlx_lm.utils import make_kv_caches
|
||||
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
|
||||
|
||||
|
||||
class TestModels(unittest.TestCase):
|
||||
|
||||
def test_kv_cache(self):
|
||||
cache = KVCache(32, 4)
|
||||
cache = KVCache()
|
||||
|
||||
k = mx.ones((1, 4, 1, 32), mx.float16)
|
||||
v = mx.ones((1, 4, 1, 32), mx.float16)
|
||||
@@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
|
||||
|
||||
def test_rotating_kv_cache(self):
|
||||
b, h, d = 1, 2, 32
|
||||
cache = RotatingKVCache(d, h, max_size=8, step=4)
|
||||
cache = RotatingKVCache(max_size=8, step=4)
|
||||
|
||||
k = mx.random.uniform(shape=(b, h, 2, d))
|
||||
v = mx.random.uniform(shape=(b, h, 2, d))
|
||||
@@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
|
||||
idx %= 8
|
||||
|
||||
# Try with nonzero keep
|
||||
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
|
||||
cache = RotatingKVCache(max_size=8, step=4, keep=2)
|
||||
|
||||
# Check a large update
|
||||
k = mx.random.uniform(shape=(b, h, 20, d))
|
||||
@@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
|
||||
if idx >= 8:
|
||||
idx = 2
|
||||
|
||||
def test_rotating_kv_cache_chat_mode(self):
|
||||
# Test that the rotating kv cache can handle
|
||||
# alternating prompt/prefill with generation
|
||||
d = 4
|
||||
h = 2
|
||||
cache = RotatingKVCache(max_size=18, step=4)
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 8, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 8)
|
||||
self.assertEqual(cache.offset, 8)
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 1, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 9)
|
||||
self.assertEqual(cache.offset, 9)
|
||||
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 11)
|
||||
self.assertEqual(cache.offset, 11)
|
||||
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 3, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(k.shape[2], 14)
|
||||
self.assertEqual(cache.offset, 14)
|
||||
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 6, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(cache.offset, 20)
|
||||
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
|
||||
|
||||
x = mx.random.uniform(shape=(1, h, 2, d))
|
||||
k, v = cache.update_and_fetch(x, x)
|
||||
self.assertEqual(cache.offset, 22)
|
||||
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
|
||||
|
||||
def model_test_runner(self, model, model_type, vocab_size, num_layers):
|
||||
|
||||
self.assertEqual(len(model.layers), num_layers)
|
||||
@@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
|
||||
cache = make_kv_caches(model)
|
||||
cache = make_prompt_cache(model)
|
||||
outputs = model(inputs, cache)
|
||||
self.assertEqual(outputs.shape, (1, 2, vocab_size))
|
||||
self.assertEqual(outputs.dtype, t)
|
||||
@@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek(self):
|
||||
from mlx_lm.models import deepseek
|
||||
|
||||
args = deepseek.ModelArgs(
|
||||
model_type="deepseek",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=8,
|
||||
num_key_value_heads=4,
|
||||
)
|
||||
model = deepseek.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_deepseek_v2(self):
|
||||
from mlx_lm.models import deepseek_v2
|
||||
|
||||
args = deepseek_v2.ModelArgs(
|
||||
model_type="deepseek_v2",
|
||||
vocab_size=1024,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
moe_intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
kv_lora_rank=4,
|
||||
q_lora_rank=4,
|
||||
qk_rope_head_dim=32,
|
||||
v_head_dim=16,
|
||||
qk_nope_head_dim=32,
|
||||
rope_scaling={
|
||||
"beta_fast": 32,
|
||||
"beta_slow": 1,
|
||||
"factor": 40,
|
||||
"mscale": 1.0,
|
||||
"mscale_all_dim": 1.0,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"type": "yarn",
|
||||
},
|
||||
)
|
||||
model = deepseek_v2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gemma2(self):
|
||||
from mlx_lm.models import gemma2
|
||||
|
||||
args = gemma2.ModelArgs(
|
||||
model_type="gemma2",
|
||||
hidden_size=128,
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=2,
|
||||
head_dim=32,
|
||||
rms_norm_eps=1e-4,
|
||||
vocab_size=1024,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
model = gemma2.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_gpt_bigcode(self):
|
||||
from mlx_lm.models import gpt_bigcode
|
||||
|
||||
args = gpt_bigcode.ModelArgs(
|
||||
model_type="gpt_bigcode",
|
||||
n_embd=128,
|
||||
n_layer=128,
|
||||
n_inner=256,
|
||||
n_head=4,
|
||||
n_positions=1000,
|
||||
layer_norm_epsilon=1e-5,
|
||||
vocab_size=1024,
|
||||
)
|
||||
model = gpt_bigcode.Model(args)
|
||||
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
|
||||
|
||||
def test_nemotron(self):
|
||||
from mlx_lm.models import nemotron
|
||||
|
||||
args = nemotron.ModelArgs(
|
||||
model_type="nemotron",
|
||||
hidden_size=128,
|
||||
hidden_act="gelu",
|
||||
num_hidden_layers=4,
|
||||
intermediate_size=256,
|
||||
num_attention_heads=4,
|
||||
norm_eps=1e-5,
|
||||
vocab_size=1024,
|
||||
num_key_value_heads=2,
|
||||
)
|
||||
model = nemotron.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phi3small(self):
|
||||
from mlx_lm.models import phi3small
|
||||
|
||||
args = phi3small.ModelArgs(
|
||||
model_type="phi3small",
|
||||
hidden_size=128,
|
||||
dense_attention_every_n_layers=2,
|
||||
ff_intermediate_size=256,
|
||||
gegelu_limit=1.0,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=2,
|
||||
layer_norm_epsilon=1e-4,
|
||||
vocab_size=1000,
|
||||
)
|
||||
model = phi3small.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_phimoe(self):
|
||||
from mlx_lm.models import phimoe
|
||||
|
||||
args = phimoe.ModelArgs(
|
||||
model_type="phimoe",
|
||||
vocab_size=320,
|
||||
hidden_size=128,
|
||||
intermediate_size=256,
|
||||
num_hidden_layers=4,
|
||||
num_attention_heads=4,
|
||||
num_key_value_heads=4,
|
||||
rope_scaling={
|
||||
"long_factor": [1.0] * 16,
|
||||
"long_mscale": 1.243163121016122,
|
||||
"original_max_position_embeddings": 4096,
|
||||
"short_factor": [1.0] * 16,
|
||||
"short_mscale": 1.243163121016122,
|
||||
"type": "longrope",
|
||||
},
|
||||
)
|
||||
model = phimoe.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
def test_recurrent_gemma(self):
|
||||
from mlx_lm.models import recurrent_gemma
|
||||
|
||||
args = recurrent_gemma.ModelArgs(
|
||||
model_type="recurrent_gemma",
|
||||
hidden_size=128,
|
||||
attention_bias=False,
|
||||
conv1d_width=3,
|
||||
intermediate_size=256,
|
||||
logits_soft_cap=1.0,
|
||||
num_attention_heads=4,
|
||||
num_hidden_layers=4,
|
||||
num_key_value_heads=2,
|
||||
rms_norm_eps=1e-4,
|
||||
rope_theta=1000,
|
||||
attention_window_size=1024,
|
||||
vocab_size=1000,
|
||||
block_types=["recurrent", "recurrent", "attention"],
|
||||
)
|
||||
model = recurrent_gemma.Model(args)
|
||||
self.model_test_runner(
|
||||
model, args.model_type, args.vocab_size, args.num_hidden_layers
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user