From ac58a95fbd48c66a42af810bc00d5bdbab737e40 Mon Sep 17 00:00:00 2001
From: N8
Date: Sat, 14 Dec 2024 17:08:06 -0500
Subject: [PATCH] Add rotating kvcache to save space

---
 llms/mlx_lm/models/base.py    |  4 +++-
 llms/mlx_lm/models/cohere2.py | 20 ++++++++++++++------
 llms/mlx_lm/utils.py          |  9 ++++-----
 3 files changed, 21 insertions(+), 12 deletions(-)

diff --git a/llms/mlx_lm/models/base.py b/llms/mlx_lm/models/base.py
index f02f49b1..3b5ddcb0 100644
--- a/llms/mlx_lm/models/base.py
+++ b/llms/mlx_lm/models/base.py
@@ -34,13 +34,15 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
     return mask * -1e9
 
 
-def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
+def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None):
     T = h.shape[1]
     if T > 1:
         window_size = None
         offset = 0
         if cache is not None and cache[0] is not None:
             c = cache[0]
+            if reference_idx is not None:
+                c = cache[reference_idx]
             if hasattr(c, "max_size"):
                 offset = min(c.max_size, c.offset)
                 window_size = c.max_size
diff --git a/llms/mlx_lm/models/cohere2.py b/llms/mlx_lm/models/cohere2.py
index a078409b..a2854d19 100644
--- a/llms/mlx_lm/models/cohere2.py
+++ b/llms/mlx_lm/models/cohere2.py
@@ -6,8 +6,8 @@ from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn
 
-from .base import BaseModelArgs, create_attention_mask
-
+from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
+from .cache import KVCache, RotatingKVCache
 
 @dataclass
 class ModelArgs(BaseModelArgs):
@@ -95,7 +95,6 @@ class Attention(nn.Module):
 
         if cache is not None:
             keys, values = cache.update_and_fetch(keys, values)
-
         # Apply sliding window attention if enabled
         if self.sliding_window is not None:
             window_size = self.sliding_window
@@ -104,8 +103,8 @@ class Attention(nn.Module):
                 if mask is not None:
                     mask = mask[..., -window_size:]
 
-        output = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=self.scale, mask=mask
+        output = scaled_dot_product_attention(
+            queries, keys, values, cache=cache, scale=self.scale, mask=mask
         )
 
         output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@@ -171,7 +170,7 @@ class CohereModel(nn.Module):
     ):
         h = self.embed_tokens(inputs)
 
-        mask = create_attention_mask(h, cache)
+        mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1)
 
         if cache is None:
             cache = [None] * len(self.layers)
@@ -198,6 +197,15 @@ class Model(nn.Module):
         out = self.model.embed_tokens.as_linear(out)
         out = out * self.model.args.logit_scale
         return out
+
+    def make_cache(self):
+        caches = []
+        for i in range(self.args.num_hidden_layers):
+            if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
+                caches.append(KVCache())
+            else:
+                caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
+        return caches
 
     @property
     def layers(self):
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index b87f5a24..10292d75 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -187,11 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
         and prompt_cache[0].offset > quantized_kv_start
     ):
         for i in range(len(prompt_cache)):
-            prompt_cache[i] = prompt_cache[i].to_quantized(
-                group_size=kv_group_size, bits=kv_bits
-            )
-
-
+            if isinstance(prompt_cache[i], cache.KVCache):
+                prompt_cache[i] = prompt_cache[i].to_quantized(
+                    group_size=kv_group_size, bits=kv_bits
+                )
 def generate_step(
     prompt: mx.array,
     model: nn.Module,