diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index c003940b..6707d25c 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -76,7 +76,12 @@ def setup_arg_parser():
         type=int,
         default=None,
         help="Set the MLX cache limit in GB",
-        required=False,
+    )
+    parser.add_argument(
+        "--max-kv-size",
+        type=int,
+        default=1024,
+        help="Set the maximum key-value cache size",
     )
     return parser
 
@@ -154,6 +159,7 @@ def main():
         formatter=formatter,
         temp=args.temp,
         top_p=args.top_p,
+        max_kv_size=args.max_kv_size,
     )
 
diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index 3fe276d2..3e84554c 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -1,6 +1,8 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import inspect
 from dataclasses import dataclass
-from typing import List, Optional
+from typing import Any, List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -44,6 +46,100 @@ class KVCache:
         self.values[..., prev : self.offset, :] = values
         return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
 
+    def state(self):
+        return self.keys, self.values
+
+
+class RotatingKVCache:
+
+    def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
+        self.n_kv_heads = n_kv_heads
+        if isinstance(head_dim, int):
+            self.k_head_dim = self.v_head_dim = head_dim
+        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
+            self.k_head_dim, self.v_head_dim = head_dim
+        else:
+            raise ValueError("head_dim must be an int or a tuple of two ints")
+        self.keep = keep
+        self.keys = None
+        self.values = None
+        self.offset = 0
+        self.max_size = max_size
+        self.step = step
+        self._idx = 0
+
+    def _trim(self, trim_size, v, append=None):
+        to_cat = []
+        if trim_size > 0:
+            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
+        else:
+            to_cat = [v]
+        if append is not None:
+            to_cat.append(append)
+        return mx.concatenate(to_cat, axis=2)
+
+    def update_and_fetch(self, keys, values):
+        prev = self.offset
+        B, _, S = keys.shape[:3]
+
+        # Prefill mode
+        if S > 1:
+            if self.keys is None:
+                self.keys = keys
+                self.values = values
+            else:
+                # The largest size is self.max_size + S - 1 to ensure
+                # every token gets at least self.max_size context
+                trim_size = self.keys.shape[2] - self.max_size + 1
+                self.keys = self._trim(trim_size, self.keys, keys)
+                self.values = self._trim(trim_size, self.values, values)
+            self.offset += S
+            self._idx = self.keys.shape[2]
+            return self.keys, self.values
+
+        # Generation mode
+        # May not have hit the max size yet, so potentially
+        # keep growing the cache
+        if self.keys is None or (
+            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
+        ):
+            new_size = min(self.step, self.max_size - prev)
+            k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
+            v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
+            new_k = mx.zeros(k_shape, keys.dtype)
+            new_v = mx.zeros(v_shape, values.dtype)
+            if self.keys is not None:
+                self.keys = mx.concatenate([self.keys, new_k], axis=2)
+                self.values = mx.concatenate([self.values, new_v], axis=2)
+            else:
+                self.keys, self.values = new_k, new_v
+            self._idx = prev
+
+        # Trim if needed
+        trim_size = self.keys.shape[2] - self.max_size
+        if trim_size > 0:
+            self.keys = self._trim(trim_size, self.keys)
+            self.values = self._trim(trim_size, self.values)
+            self._idx = self.max_size
+
+        # Rotate
+        if self._idx == self.max_size:
+            self._idx = self.keep
+
+        # Assign
+        self.keys[..., self._idx : self._idx + 1, :] = keys
+        self.values[..., self._idx : self._idx + 1, :] = values
+        self.offset += 1
+        self._idx += 1
+
+        # If the buffer is not full, slice off the end
+        if self.offset < self.max_size:
+            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
+        return self.keys, self.values
+
+    def state(self):
+        return self.keys, self.values
+
 
 @dataclass
 class BaseModelArgs:
@@ -65,13 +161,17 @@ def create_additive_causal_mask(N: int, offset: int = 0):
     return mask * -1e9
 
 
-def create_attention_mask(h: mx.array, cache: Optional[List[KVCache]] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
     T = h.shape[1]
     if T > 1:
-        # Input consists of multiple tokens, create a causal mask so that prior
-        # tokens do not give attention to later tokens. If a cache is in place
-        # (because e.g. prompt reuse), offset the mask accordingly.
-        offset = cache[0].offset if cache is not None and cache[0] is not None else 0
+        if cache is not None and cache[0] is not None:
+            c = cache[0]
+            if isinstance(c, RotatingKVCache):
+                offset = min(c.max_size - 1, c.offset)
+            else:
+                offset = c.offset
+        else:
+            offset = 0
         mask = create_additive_causal_mask(T, offset)
         mask = mask.astype(h.dtype)
     else:
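
The rotating cache above pins the first `keep` tokens and, once the buffer is full, overwrites the oldest of the remaining slots, so the buffer size stays bounded while `offset` keeps counting tokens. A small sketch of the observable behavior, assuming this revision of `mlx_lm` is importable (shapes are `(batch, n_kv_heads, seq_len, head_dim)`):

    import mlx.core as mx
    from mlx_lm.models.base import RotatingKVCache

    B, H, D = 1, 4, 32
    cache = RotatingKVCache(D, H, max_size=8, keep=2)

    def step(n):
        k = mx.random.uniform(shape=(B, H, n, D))
        v = mx.random.uniform(shape=(B, H, n, D))
        return cache.update_and_fetch(k, v)

    # Prefill: multi-token updates may exceed max_size temporarily, since
    # the invariant is that every token sees at least max_size of context.
    keys, _ = step(6)   # keys.shape[2] == 6, nothing trimmed yet
    keys, _ = step(6)   # keys.shape[2] == 12, at most max_size + S - 1

    # Generation: single-token updates trim back to max_size, then rotate,
    # wrapping the write index back to `keep` (2) rather than 0.
    for _ in range(5):
        keys, _ = step(1)
    print(keys.shape[2], cache.offset)  # 8 17: bounded size, growing offset

This is also why `create_attention_mask` clamps the offset to `max_size - 1` for a `RotatingKVCache`: positions older than that have already been overwritten.
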
diff --git a/llms/mlx_lm/models/cohere.py b/llms/mlx_lm/models/cohere.py
index 7dc2b9bf..cfcf2945 100644
--- a/llms/mlx_lm/models/cohere.py
+++ b/llms/mlx_lm/models/cohere.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
diff --git a/llms/mlx_lm/models/dbrx.py b/llms/mlx_lm/models/dbrx.py
index 7a2a7a7d..f0214549 100644
--- a/llms/mlx_lm/models/dbrx.py
+++ b/llms/mlx_lm/models/dbrx.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index bd743e53..f320b564 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index 323ebaa6..c6150284 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
diff --git a/llms/mlx_lm/models/gemma2.py b/llms/mlx_lm/models/gemma2.py
index d4bd8a5d..1d410a15 100644
--- a/llms/mlx_lm/models/gemma2.py
+++ b/llms/mlx_lm/models/gemma2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
diff --git a/llms/mlx_lm/models/gpt2.py b/llms/mlx_lm/models/gpt2.py
index 81f71cac..8a770936 100644
--- a/llms/mlx_lm/models/gpt2.py
+++ b/llms/mlx_lm/models/gpt2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/gpt_bigcode.py b/llms/mlx_lm/models/gpt_bigcode.py
index a5336203..652eb9e4 100644
--- a/llms/mlx_lm/models/gpt_bigcode.py
+++ b/llms/mlx_lm/models/gpt_bigcode.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/gpt_neox.py b/llms/mlx_lm/models/gpt_neox.py
index 1d2f74b7..c2aaa9ea 100644
--- a/llms/mlx_lm/models/gpt_neox.py
+++ b/llms/mlx_lm/models/gpt_neox.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/internlm2.py b/llms/mlx_lm/models/internlm2.py
index 2ee2af2d..bcc0cf0c 100644
--- a/llms/mlx_lm/models/internlm2.py
+++ b/llms/mlx_lm/models/internlm2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/llama.py b/llms/mlx_lm/models/llama.py
index 2f323245..192e591f 100644
--- a/llms/mlx_lm/models/llama.py
+++ b/llms/mlx_lm/models/llama.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index a3d01cbb..df0670be 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/mixtral.py b/llms/mlx_lm/models/mixtral.py
index c7d8c5c5..2db57752 100644
--- a/llms/mlx_lm/models/mixtral.py
+++ b/llms/mlx_lm/models/mixtral.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index 8a28ad74..59849c96 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from sys import exit
 from typing import Optional, Tuple
diff --git a/llms/mlx_lm/models/openelm.py b/llms/mlx_lm/models/openelm.py
index 3f0d2605..19d3c027 100644
--- a/llms/mlx_lm/models/openelm.py
+++ b/llms/mlx_lm/models/openelm.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/phi.py b/llms/mlx_lm/models/phi.py
index 520ac1ad..fd3fd709 100644
--- a/llms/mlx_lm/models/phi.py
+++ b/llms/mlx_lm/models/phi.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from typing import Tuple
diff --git a/llms/mlx_lm/models/phi3.py b/llms/mlx_lm/models/phi3.py
index 2536aacb..f8facdb1 100644
--- a/llms/mlx_lm/models/phi3.py
+++ b/llms/mlx_lm/models/phi3.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/phi3small.py b/llms/mlx_lm/models/phi3small.py
index de075652..665dbc73 100644
--- a/llms/mlx_lm/models/phi3small.py
+++ b/llms/mlx_lm/models/phi3small.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from functools import partial
diff --git a/llms/mlx_lm/models/phixtral.py b/llms/mlx_lm/models/phixtral.py
index f0aef0c9..bb67615d 100644
--- a/llms/mlx_lm/models/phixtral.py
+++ b/llms/mlx_lm/models/phixtral.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import inspect
 import math
 from dataclasses import dataclass
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index 47a9ea4f..5d2b7586 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, List, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/qwen.py b/llms/mlx_lm/models/qwen.py
index 67816599..6d2c7bbf 100644
--- a/llms/mlx_lm/models/qwen.py
+++ b/llms/mlx_lm/models/qwen.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Tuple
 
diff --git a/llms/mlx_lm/models/qwen2.py b/llms/mlx_lm/models/qwen2.py
index cb8268aa..b3ce02a3 100644
--- a/llms/mlx_lm/models/qwen2.py
+++ b/llms/mlx_lm/models/qwen2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
 
diff --git a/llms/mlx_lm/models/qwen2_moe.py b/llms/mlx_lm/models/qwen2_moe.py
index 121ab813..ff7831f3 100644
--- a/llms/mlx_lm/models/qwen2_moe.py
+++ b/llms/mlx_lm/models/qwen2_moe.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from typing import Dict, Optional, Tuple, Union
diff --git a/llms/mlx_lm/models/recurrent_gemma.py b/llms/mlx_lm/models/recurrent_gemma.py
index 428431e3..34750ace 100644
--- a/llms/mlx_lm/models/recurrent_gemma.py
+++ b/llms/mlx_lm/models/recurrent_gemma.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from typing import List, Literal, Optional
@@ -53,6 +55,9 @@ class RecurrentCache:
     def update(self, conv_state, recurrent_state):
         self._cache = (conv_state, recurrent_state)
 
+    def state(self):
+        return self._cache
+
 
 class WindowKVCache:
 
@@ -80,6 +85,9 @@ class WindowKVCache:
         self.values = _update(self.values, values)
         return self.keys, self.values
 
+    def state(self):
+        return self.keys, self.values
+
 
 class RMSNorm(nn.Module):
     def __init__(self, dims: int, eps: float = 1e-5):
diff --git a/llms/mlx_lm/models/stablelm.py b/llms/mlx_lm/models/stablelm.py
index 9b4d043c..b340de28 100644
--- a/llms/mlx_lm/models/stablelm.py
+++ b/llms/mlx_lm/models/stablelm.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from dataclasses import dataclass
 from typing import Tuple
diff --git a/llms/mlx_lm/models/starcoder2.py b/llms/mlx_lm/models/starcoder2.py
index a6eb5377..9cec0e39 100644
--- a/llms/mlx_lm/models/starcoder2.py
+++ b/llms/mlx_lm/models/starcoder2.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Optional, Tuple
 
diff --git a/llms/mlx_lm/models/su_rope.py b/llms/mlx_lm/models/su_rope.py
index cdf6ceaf..2ee20a63 100644
--- a/llms/mlx_lm/models/su_rope.py
+++ b/llms/mlx_lm/models/su_rope.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 from typing import List, Union
 
diff --git a/llms/mlx_lm/models/switch_layers.py b/llms/mlx_lm/models/switch_layers.py
index 00aa65d8..4a157473 100644
--- a/llms/mlx_lm/models/switch_layers.py
+++ b/llms/mlx_lm/models/switch_layers.py
@@ -1,3 +1,5 @@
+# Copyright © 2023-2024 Apple Inc.
+
 import math
 
 import mlx.core as mx
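
All of the cache variants (`KVCache`, `RotatingKVCache`, `RecurrentCache`, `WindowKVCache`) now expose the same `state()` accessor returning the arrays that back the cache. That gives callers a single duck-typed hook for materializing a heterogeneous cache list; the chunked prefill added to `utils.py` below is the intended consumer. An illustrative helper, not part of the diff:

    import mlx.core as mx

    def eval_cache_state(cache):
        # state() returns the underlying arrays for every cache type,
        # so one eval works regardless of which classes are in the list.
        mx.eval([c.state() for c in cache])
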
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index e7a9dba8..44196766 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -19,7 +19,7 @@ from mlx.utils import tree_flatten
 from transformers import PreTrainedTokenizer
 
 # Local imports
-from .models.base import KVCache
+from .models.base import KVCache, RotatingKVCache
 from .sample_utils import categorical_sampling, min_p_sampling, top_p_sampling
 from .tokenizer_utils import TokenizerWrapper, load_tokenizer
 from .tuner.utils import apply_lora_layers
@@ -136,6 +136,8 @@ def generate_step(
     min_p: float = 0.0,
     min_tokens_to_keep: int = 1,
     logit_bias: Optional[Dict[int, float]] = None,
+    prefill_step_size: int = 512,
+    max_kv_size: Optional[int] = None,
 ) -> Generator[Tuple[mx.array, mx.array], None, None]:
     """
     A generator producing token ids based on the given prompt from the model.
@@ -156,6 +158,9 @@
         min_tokens_to_keep (int, optional): Minimum number of tokens that cannot
           be filtered by min_p sampling.
         logit_bias (dictionary, optional): Additive logit bias.
+        prefill_step_size (int): Step size for processing the prompt.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
 
     Yields:
         Generator[Tuple[mx.array, mx.array], None, None]: A generator producing
@@ -197,7 +202,13 @@
         if isinstance(model.n_kv_heads, int)
         else model.n_kv_heads
     )
-    cache = [KVCache(model.head_dim, n) for n in kv_heads]
+    if max_kv_size is not None:
+        cache = [
+            RotatingKVCache(model.head_dim, n, max_size=max_kv_size, keep=4)
+            for n in kv_heads
+        ]
+    else:
+        cache = [KVCache(model.head_dim, n) for n in kv_heads]
 
     repetition_context = prompt.tolist()
 
@@ -223,6 +234,11 @@
             repetition_context = repetition_context[-repetition_context_size:]
         return y, logprobs.squeeze(0)
 
+    while y.size > prefill_step_size:
+        model(y[:prefill_step_size][None], cache=cache)
+        mx.eval([c.state() for c in cache])
+        y = y[prefill_step_size:]
+
     y, logprobs = _step(y)
 
     mx.async_eval(y)
@@ -343,8 +359,10 @@ def generate(
             return
         prompt_tps = prompt_tokens.size / prompt_time
         gen_tps = (token_count - 1) / gen_time
-        print(f"Prompt: {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {gen_tps:.3f} tokens-per-sec")
+        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+        peak_mem = mx.metal.get_peak_memory() / 2**30
+        print(f"Peak memory: {peak_mem:.3f} GB")
 
     return detokenizer.text
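
With `max_kv_size` threaded through `generate_step`, callers can cap cache memory for long generations. A rough usage sketch; the model id is a placeholder, and `generate` is assumed to forward extra keyword arguments on to `generate_step`, which the `generate.py` change above relies on:

    from mlx_lm import load, generate

    # Placeholder repo id; any mlx_lm-compatible model works here.
    model, tokenizer = load("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

    # max_kv_size=512 makes generate_step build RotatingKVCache instances
    # (keep=4): memory stays bounded, at the cost of dropping context older
    # than roughly 512 tokens. The prompt is also processed in chunks of
    # prefill_step_size (default 512) rather than in one shot.
    text = generate(
        model,
        tokenizer,
        prompt="Write a very long story about a lighthouse keeper.",
        max_tokens=2000,
        max_kv_size=512,
        verbose=True,
    )

The same cap is available from the command line via the new flag, e.g. `python -m mlx_lm.generate --model <model> --prompt <prompt> --max-kv-size 512`.
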
-__version__ = "0.16.0" +__version__ = "0.17.0" diff --git a/llms/tests/test_models.py b/llms/tests/test_models.py index 19341981..fcf1dc33 100644 --- a/llms/tests/test_models.py +++ b/llms/tests/test_models.py @@ -4,7 +4,7 @@ import unittest import mlx.core as mx from mlx.utils import tree_map -from mlx_lm.models.base import KVCache +from mlx_lm.models.base import KVCache, RotatingKVCache class TestModels(unittest.TestCase): @@ -29,6 +29,64 @@ class TestModels(unittest.TestCase): self.assertTrue(mx.array_equal(v_up, expected)) self.assertEqual(cache.offset, cache.step + 1) + def test_rotating_kv_cache(self): + b, h, d = 1, 2, 32 + cache = RotatingKVCache(d, h, max_size=8, step=4) + + k = mx.random.uniform(shape=(b, h, 2, d)) + v = mx.random.uniform(shape=(b, h, 2, d)) + + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up, k)) + self.assertTrue(mx.array_equal(v_up, v)) + self.assertEqual(cache.offset, 2) + + k = mx.random.uniform(shape=(b, h, 5, d)) + v = mx.random.uniform(shape=(b, h, 5, d)) + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up[..., 2:, :], k)) + self.assertTrue(mx.array_equal(v_up[..., 2:, :], v)) + + k = mx.random.uniform(shape=(b, h, 4, d)) + v = mx.random.uniform(shape=(b, h, 4, d)) + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up[..., -4:, :], k)) + self.assertTrue(mx.array_equal(v_up[..., -4:, :], v)) + + idx = 0 + for _ in range(10): + k = mx.random.uniform(shape=(b, h, 1, d)) + v = mx.random.uniform(shape=(b, h, 1, d)) + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k)) + self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v)) + idx += 1 + idx %= 8 + + # Try with nonzero keep + cache = RotatingKVCache(d, h, max_size=8, step=4, keep=2) + + # Check a large update + k = mx.random.uniform(shape=(b, h, 20, d)) + v = mx.random.uniform(shape=(b, h, 20, d)) + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up, k)) + self.assertTrue(mx.array_equal(v_up, v)) + + # A bunch of small updates + self.assertEqual(cache.offset, 20) + idx = 2 + for i in range(10): + k = mx.random.uniform(shape=(b, h, 1, d)) + v = mx.random.uniform(shape=(b, h, 1, d)) + k_up, v_up = cache.update_and_fetch(k, v) + self.assertTrue(mx.array_equal(k_up[..., idx : idx + 1, :], k)) + self.assertTrue(mx.array_equal(v_up[..., idx : idx + 1, :], v)) + self.assertEqual(cache.offset, 21 + i) + idx += 1 + if idx >= 8: + idx = 2 + def model_test_runner(self, model, model_type, vocab_size, num_layers): self.assertEqual(len(model.layers), num_layers)