More cache improvements (#1015)

* fix rotating kv cache for chat use case

* reorg + fixes to caching: unify prompt caching across cache types and use cases, e.g. caching during a chat

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache api

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date: 2024-10-07 20:45:51 -07:00
Committed by: GitHub
Parent: 9bc53fc210
Commit: fca087be49
43 changed files with 1151 additions and 691 deletions
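
As a rough illustration of the unified flow the bullets above describe (an editor's sketch, not part of the diff; the model path and the generate_step/prompt_cache calling convention are taken from the new tests below):

from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# One prompt cache for the whole chat session.
prompt_cache = make_prompt_cache(model)

# First turn: the prompt's keys/values are written into prompt_cache.
prompt = tokenizer.encode("Hello!", return_tensors="mlx")[0]
for _, (token, _) in zip(range(16), generate_step(prompt, model, prompt_cache=prompt_cache)):
    pass  # detokenize / stream the token here

# Later turns reuse the cached context: only the new message is prefilled.
prompt = tokenizer.encode("Tell me more.", return_tensors="mlx")[0]
for _, (token, _) in zip(range(16), generate_step(prompt, model, prompt_cache=prompt_cache)):
    pass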

@@ -1,17 +1,15 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_map
from mlx_lm.models.base import KVCache, RotatingKVCache
from mlx_lm.utils import make_kv_caches
from mlx_lm.models.cache import KVCache, RotatingKVCache, make_prompt_cache
class TestModels(unittest.TestCase):
def test_kv_cache(self):
cache = KVCache(32, 4)
cache = KVCache()
k = mx.ones((1, 4, 1, 32), mx.float16)
v = mx.ones((1, 4, 1, 32), mx.float16)
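
The hunk above reflects an API simplification: KVCache no longer takes per-layer geometry arguments and instead infers head count and head dimension from the first update. A minimal sketch of the new calling convention, using the same shapes as the test:

import mlx.core as mx
from mlx_lm.models.cache import KVCache

cache = KVCache()  # no head_dim / n_kv_heads arguments anymore
k = mx.ones((1, 4, 1, 32), mx.float16)  # (batch, kv_heads, seq_len, head_dim)
v = mx.ones((1, 4, 1, 32), mx.float16)
k_all, v_all = cache.update_and_fetch(k, v)  # returns everything cached so far
assert cache.offset == 1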
@@ -32,7 +30,7 @@ class TestModels(unittest.TestCase):
def test_rotating_kv_cache(self):
b, h, d = 1, 2, 32
cache = RotatingKVCache(d, h, max_size=8, step=4)
cache = RotatingKVCache(max_size=8, step=4)
k = mx.random.uniform(shape=(b, h, 2, d))
v = mx.random.uniform(shape=(b, h, 2, d))
@@ -65,7 +63,7 @@ class TestModels(unittest.TestCase):
idx %= 8
# Try with nonzero keep
cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2)
cache = RotatingKVCache(max_size=8, step=4, keep=2)
# Check a large update
k = mx.random.uniform(shape=(b, h, 20, d))
@@ -88,6 +86,46 @@ class TestModels(unittest.TestCase):
if idx >= 8:
idx = 2
def test_rotating_kv_cache_chat_mode(self):
# Test that the rotating kv cache can handle
# alternating prompt/prefill with generation
d = 4
h = 2
cache = RotatingKVCache(max_size=18, step=4)
x = mx.random.uniform(shape=(1, h, 8, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 8)
self.assertEqual(cache.offset, 8)
x = mx.random.uniform(shape=(1, h, 1, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 9)
self.assertEqual(cache.offset, 9)
self.assertTrue(mx.allclose(x, k[..., 8:9, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 11)
self.assertEqual(cache.offset, 11)
self.assertTrue(mx.allclose(x, k[..., 9:11, :]))
x = mx.random.uniform(shape=(1, h, 3, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(k.shape[2], 14)
self.assertEqual(cache.offset, 14)
self.assertTrue(mx.allclose(x, k[..., 11:14, :]))
x = mx.random.uniform(shape=(1, h, 6, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 20)
self.assertTrue(mx.allclose(x, k[..., -6:, :]))
x = mx.random.uniform(shape=(1, h, 2, d))
k, v = cache.update_and_fetch(x, x)
self.assertEqual(cache.offset, 22)
self.assertTrue(mx.allclose(x, k[..., -2:, :]))
def model_test_runner(self, model, model_type, vocab_size, num_layers):
self.assertEqual(len(model.layers), num_layers)
@@ -101,7 +139,7 @@ class TestModels(unittest.TestCase):
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
cache = make_kv_caches(model)
cache = make_prompt_cache(model)
outputs = model(inputs, cache)
self.assertEqual(outputs.shape, (1, 2, vocab_size))
self.assertEqual(outputs.dtype, t)
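
The new chat-mode test above pins down the rotating cache's contract: offset keeps counting every token ever seen, the cache grows normally until it reaches max_size, and after that both multi-token prompt updates and single-token generation steps leave the newest tokens at the end of what update_and_fetch returns. A condensed sketch of the same behavior, asserting only what the test itself asserts:

import mlx.core as mx
from mlx_lm.models.cache import RotatingKVCache

cache = RotatingKVCache(max_size=8, step=4, keep=2)

x = mx.random.uniform(shape=(1, 2, 6, 4))  # 6-token prefill
k, _ = cache.update_and_fetch(x, x)
assert k.shape[2] == 6 and cache.offset == 6  # below max_size: grows normally

x = mx.random.uniform(shape=(1, 2, 6, 4))  # 6 more tokens pushes past max_size
k, _ = cache.update_and_fetch(x, x)
assert cache.offset == 12  # offset counts all tokens seen
# Storage stays bounded: entries past the first `keep` rotate out,
# and the newest tokens sit at the end of the returned keys.
assert mx.allclose(x, k[..., -6:, :])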
@@ -549,6 +587,179 @@ class TestModels(unittest.TestCase):
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek(self):
from mlx_lm.models import deepseek
args = deepseek.ModelArgs(
model_type="deepseek",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=8,
num_key_value_heads=4,
)
model = deepseek.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_deepseek_v2(self):
from mlx_lm.models import deepseek_v2
args = deepseek_v2.ModelArgs(
model_type="deepseek_v2",
vocab_size=1024,
hidden_size=128,
intermediate_size=256,
moe_intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
kv_lora_rank=4,
q_lora_rank=4,
qk_rope_head_dim=32,
v_head_dim=16,
qk_nope_head_dim=32,
rope_scaling={
"beta_fast": 32,
"beta_slow": 1,
"factor": 40,
"mscale": 1.0,
"mscale_all_dim": 1.0,
"original_max_position_embeddings": 4096,
"type": "yarn",
},
)
model = deepseek_v2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gemma2(self):
from mlx_lm.models import gemma2
args = gemma2.ModelArgs(
model_type="gemma2",
hidden_size=128,
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=2,
head_dim=32,
rms_norm_eps=1e-4,
vocab_size=1024,
num_key_value_heads=2,
)
model = gemma2.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_gpt_bigcode(self):
from mlx_lm.models import gpt_bigcode
args = gpt_bigcode.ModelArgs(
model_type="gpt_bigcode",
n_embd=128,
n_layer=128,
n_inner=256,
n_head=4,
n_positions=1000,
layer_norm_epsilon=1e-5,
vocab_size=1024,
)
model = gpt_bigcode.Model(args)
self.model_test_runner(model, args.model_type, args.vocab_size, args.n_layer)
def test_nemotron(self):
from mlx_lm.models import nemotron
args = nemotron.ModelArgs(
model_type="nemotron",
hidden_size=128,
hidden_act="gelu",
num_hidden_layers=4,
intermediate_size=256,
num_attention_heads=4,
norm_eps=1e-5,
vocab_size=1024,
num_key_value_heads=2,
)
model = nemotron.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phi3small(self):
from mlx_lm.models import phi3small
args = phi3small.ModelArgs(
model_type="phi3small",
hidden_size=128,
dense_attention_every_n_layers=2,
ff_intermediate_size=256,
gegelu_limit=1.0,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=2,
layer_norm_epsilon=1e-4,
vocab_size=1000,
)
model = phi3small.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_phimoe(self):
from mlx_lm.models import phimoe
args = phimoe.ModelArgs(
model_type="phimoe",
vocab_size=320,
hidden_size=128,
intermediate_size=256,
num_hidden_layers=4,
num_attention_heads=4,
num_key_value_heads=4,
rope_scaling={
"long_factor": [1.0] * 16,
"long_mscale": 1.243163121016122,
"original_max_position_embeddings": 4096,
"short_factor": [1.0] * 16,
"short_mscale": 1.243163121016122,
"type": "longrope",
},
)
model = phimoe.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
def test_recurrent_gemma(self):
from mlx_lm.models import recurrent_gemma
args = recurrent_gemma.ModelArgs(
model_type="recurrent_gemma",
hidden_size=128,
attention_bias=False,
conv1d_width=3,
intermediate_size=256,
logits_soft_cap=1.0,
num_attention_heads=4,
num_hidden_layers=4,
num_key_value_heads=2,
rms_norm_eps=1e-4,
rope_theta=1000,
attention_window_size=1024,
vocab_size=1000,
block_types=["recurrent", "recurrent", "attention"],
)
model = recurrent_gemma.Model(args)
self.model_test_runner(
model, args.model_type, args.vocab_size, args.num_hidden_layers
)
if __name__ == "__main__":
unittest.main()
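
The model tests now build caches with make_prompt_cache rather than the old make_kv_caches helper. A sketch of how it appears meant to be called; the max_kv_size keyword for bounded rotating caches is my reading of this commit, so treat it as an assumption. The next file, added by this commit, exercises the full prompt-cache API.

from mlx_lm.models.cache import make_prompt_cache
from mlx_lm.utils import load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")

# One cache object per layer: KVCache by default, model-specific
# caches (e.g. MambaCache) where the architecture requires them.
cache = make_prompt_cache(model)

# Assumption: a max_kv_size argument swaps in RotatingKVCache layers
# so memory stays bounded during long generations.
bounded_cache = make_prompt_cache(model, max_kv_size=512)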

@@ -0,0 +1,220 @@
# Copyright © 2024 Apple Inc.
import os
import tempfile
import unittest
import mlx.core as mx
from mlx_lm.models.cache import (
KVCache,
MambaCache,
RotatingKVCache,
load_prompt_cache,
make_prompt_cache,
save_prompt_cache,
trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestPromptCache(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_save_load(self):
cache = [KVCache() for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertEqual(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Test with metadata
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
metadata = {"a": "b", "c": "d"}
save_prompt_cache(cache_file, cache, metadata)
_, loaded_metadata = load_prompt_cache(cache_file, return_metadata=True)
self.assertEqual(metadata, loaded_metadata)
def test_save_load_rotating_cache(self):
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
# Test with rotating cache
cache = [RotatingKVCache(max_size=8, keep=2) for _ in range(4)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
self.assertEqual(len(cache), len(loaded_cache))
for c, lc in zip(cache, loaded_cache):
self.assertEqual(c.offset, lc.offset)
self.assertEqual(c.keep, lc.keep)
self.assertEqual(c.max_size, lc.max_size)
self.assertEqual(c.step, lc.step)
self.assertTrue(mx.array_equal(c.state[0], lc.state[0]))
self.assertTrue(mx.array_equal(c.state[1], lc.state[1]))
# Do a couple single token updates to get a rotation
for _ in range(2):
for c in cache:
x = mx.random.uniform(shape=(1, 8, 1, 4))
c.update_and_fetch(x, x)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
for c, lc in zip(cache, loaded_cache):
x = mx.random.uniform(shape=(1, 8, 1, 4))
k, v = c.update_and_fetch(x, x)
lk, lv = lc.update_and_fetch(x, x)
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(k, lk))
self.assertTrue(mx.array_equal(v, lv))
def test_save_load_mixed_cache(self):
cache_file = os.path.join(self.test_dir, "prompt_cache.safetensors")
cache = [MambaCache(), KVCache(), RotatingKVCache(8), MambaCache()]
for c in cache:
if isinstance(c, MambaCache):
c[0] = mx.random.uniform(shape=(4, 4, 4))
c[1] = mx.random.uniform(shape=(4, 4, 4))
else:
x = mx.random.uniform(shape=(4, 4, 7, 4))
y = mx.random.uniform(shape=(4, 4, 7, 4))
c.update_and_fetch(x, y)
save_prompt_cache(cache_file, cache)
loaded_cache = load_prompt_cache(cache_file)
for c, lc in zip(cache, loaded_cache):
if isinstance(c, MambaCache):
self.assertTrue(mx.array_equal(c[0], lc[0]))
self.assertTrue(mx.array_equal(c[1], lc[1]))
else:
x = mx.random.uniform(shape=(4, 4, 1, 4))
y = mx.random.uniform(shape=(4, 4, 1, 4))
k, v = c.update_and_fetch(x, y)
lk, lv = lc.update_and_fetch(x, y)
self.assertEqual(c.offset, lc.offset)
self.assertTrue(mx.array_equal(k, lk))
self.assertTrue(mx.array_equal(v, lv))
def test_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
results = zip(range(4), generate_step(prompt, model))
toks, all_logits = zip(*(r[1] for r in results))
prompt_cache = make_prompt_cache(model)
i = 0
for _, (tok, logits) in zip(
range(2), generate_step(prompt, model, prompt_cache=prompt_cache)
):
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
i += 1
for _, (tok, logits) in zip(
range(1),
generate_step(mx.array([toks[i]]), model, prompt_cache=prompt_cache),
):
i += 1
self.assertEqual(tok, toks[i])
self.assertTrue(mx.allclose(logits, all_logits[i]))
def test_trim_cache(self):
cache = [KVCache() for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
# Trim
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 7)
# Trim more tokens than remain
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 3)
# Can't trim mamba cache
cache = [MambaCache() for _ in range(2)]
for c in cache:
c.state = mx.zeros((5, 5))
num_trimmed = trim_prompt_cache(cache, 7)
self.assertEqual(num_trimmed, 0)
# All caches have to be trimmable
cache = [MambaCache(), KVCache()]
cache[0].state = mx.zeros((5, 5))
x = mx.random.uniform(shape=(1, 8, 10, 4))
cache[1].update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 1)
self.assertEqual(num_trimmed, 0)
cache = [RotatingKVCache(max_size=6) for _ in range(2)]
for c in cache:
x = mx.random.uniform(shape=(1, 8, 5, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 4)
# Can't trim fixed-size KV cache after processing
# more than max_kv_size tokens
for c in cache:
x = mx.random.uniform(shape=(1, 8, 10, 4))
c.update_and_fetch(x, x)
num_trimmed = trim_prompt_cache(cache, 4)
self.assertEqual(num_trimmed, 0)
def test_trim_cache_with_generate(self):
model, tokenizer = load(HF_MODEL_PATH)
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]
prompt_cache = make_prompt_cache(model)
# Generate one token so we process the full prompt
last_tok, _ = next(generate_step(prompt, model, prompt_cache=prompt_cache))
last_tok = mx.array([last_tok])
# Generate two more tokens
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
toks, all_logits = zip(*(r[1] for r in results))
# To get back to the cache just after processing the prompt,
# trim by 3 tokens
trim_prompt_cache(prompt_cache, 3)
# Generate the same thing again
results = zip(
range(2), generate_step(last_tok, model, prompt_cache=prompt_cache)
)
second_toks, second_all_logits = zip(*(r[1] for r in results))
self.assertEqual(toks, second_toks)
self.assertTrue(
all(mx.allclose(l, l2) for l, l2 in zip(all_logits, second_all_logits))
)
if __name__ == "__main__":
unittest.main()
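
Taken together, the new test file pins down a small persistence and rewind API for prompt caches. A usage sketch under the same assumptions as the tests:

from mlx_lm.models.cache import (
    load_prompt_cache,
    make_prompt_cache,
    save_prompt_cache,
    trim_prompt_cache,
)
from mlx_lm.utils import generate_step, load

model, tokenizer = load("mlx-community/Qwen1.5-0.5B-Chat-4bit")
prompt = tokenizer.encode("this is a prompt", return_tensors="mlx")[0]

# Process the prompt (plus one generated token) into the cache.
prompt_cache = make_prompt_cache(model)
next(generate_step(prompt, model, prompt_cache=prompt_cache))

# Persist the cache, optionally with string-valued metadata, then restore it.
save_prompt_cache("prompt.safetensors", prompt_cache, {"model": "qwen"})
restored, metadata = load_prompt_cache("prompt.safetensors", return_metadata=True)

# Rewind by N tokens; returns how many were actually trimmed
# (0 if any layer's cache, e.g. a MambaCache, cannot be trimmed).
num_trimmed = trim_prompt_cache(restored, 1)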