Add rotating kvcache to save space

This commit is contained in:
N8 2024-12-14 17:08:06 -05:00
parent 406c7f300f
commit ac58a95fbd
3 changed files with 21 additions and 12 deletions

View File

@ -34,13 +34,15 @@ def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = Non
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
def create_attention_mask(h: mx.array, cache: Optional[Any] = None, reference_idx: Optional[int] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if reference_idx is not None:
c = cache[reference_idx]
if hasattr(c, "max_size"):
offset = min(c.max_size, c.offset)
window_size = c.max_size

View File

@ -6,8 +6,8 @@ from typing import Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs, create_attention_mask
from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention
from .cache import KVCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
@ -95,7 +95,6 @@ class Attention(nn.Module):
if cache is not None:
keys, values = cache.update_and_fetch(keys, values)
# Apply sliding window attention if enabled
if self.sliding_window is not None:
window_size = self.sliding_window
@ -104,8 +103,8 @@ class Attention(nn.Module):
if mask is not None:
mask = mask[..., -window_size:]
output = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=self.scale, mask=mask
output = scaled_dot_product_attention(
queries, keys, values, cache=cache, scale=self.scale, mask=mask
)
output = output.transpose(0, 2, 1, 3).reshape(B, L, -1)
@ -171,7 +170,7 @@ class CohereModel(nn.Module):
):
h = self.embed_tokens(inputs)
mask = create_attention_mask(h, cache)
mask = create_attention_mask(h, cache, reference_idx=self.args.sliding_window_pattern - 1)
if cache is None:
cache = [None] * len(self.layers)
@ -198,6 +197,15 @@ class Model(nn.Module):
out = self.model.embed_tokens.as_linear(out)
out = out * self.model.args.logit_scale
return out
def make_cache(self):
caches = []
for i in range(self.args.num_hidden_layers):
if i % self.args.sliding_window_pattern == self.args.sliding_window_pattern - 1:
caches.append(KVCache())
else:
caches.append(RotatingKVCache(max_size=self.args.sliding_window, keep=0))
return caches
@property
def layers(self):

View File

@ -187,11 +187,10 @@ def maybe_quantize_kv_cache(prompt_cache, quantized_kv_start, kv_group_size, kv_
and prompt_cache[0].offset > quantized_kv_start
):
for i in range(len(prompt_cache)):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
if isinstance(prompt_cache[i], cache.KVCache):
prompt_cache[i] = prompt_cache[i].to_quantized(
group_size=kv_group_size, bits=kv_bits
)
def generate_step(
prompt: mx.array,
model: nn.Module,