More cache improvements (#1015)

* fix the rotating KV cache for the chat use case

* reorg + fixes to caching; unify prompt caching across cache types and use cases, e.g. caching during a chat (see the usage sketch after the file summary below)

* nit in chat

* fix tests

* fix tests

* fix tests

* docs

* chat command

* comments + docs

* Define meta_state on all Cache implementations

* fixes + trim_prompt_cache API

* fix default model

---------

Co-authored-by: Angelos Katharopoulos <a_katharopoulos@apple.com>
Author: Awni Hannun
Date: 2024-10-07 20:45:51 -07:00
Committer: GitHub
Parent: 9bc53fc210
Commit: fca087be49
43 changed files with 1151 additions and 691 deletions
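As a rough illustration of the unified prompt caching described in the commit message, the sketch below reuses a single cache across chat turns and then trims it back. It is a hypothetical usage example rather than code from this diff: it assumes make_prompt_cache and trim_prompt_cache are exported from mlx_lm.models.cache (the message itself only names trim_prompt_cache), that generate accepts a prompt_cache keyword, and the checkpoint name is a placeholder.

from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache, trim_prompt_cache

# Placeholder checkpoint; any mlx-lm compatible chat model would do.
model, tokenizer = load("mlx-community/SomeChatModel-4bit")

# One cache object per model layer, reused across every turn of the chat.
prompt_cache = make_prompt_cache(model)

# First turn: the prompt's keys/values stay behind in the cache.
generate(model, tokenizer, prompt="Hello!", prompt_cache=prompt_cache)

# Second turn: only the new tokens need to be prefilled.
generate(model, tokenizer, prompt="Tell me more.", prompt_cache=prompt_cache)

# Drop the most recent 32 tokens from the cache, e.g. to retry a turn.
trim_prompt_cache(prompt_cache, 32)

In a real chat loop each turn would also go through the tokenizer's chat template; the calls above are stripped down to the cache handling only.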


@@ -7,13 +7,13 @@ from typing import List, Literal, Optional
import mlx.core as mx
import mlx.nn as nn
from .base import BaseModelArgs
from .base import BaseModelArgs, create_attention_mask
from .cache import MambaCache, RotatingKVCache
@dataclass
class ModelArgs(BaseModelArgs):
model_type: str
hidden_size: int
attention_bias: bool
conv1d_width: int
hidden_size: int
@@ -36,59 +36,6 @@ class ModelArgs(BaseModelArgs):
self.block_types = self._block_types
def create_window_causal_mask(N: int, window_size: int):
inds = mx.arange(N)
linds = inds[:, None]
rinds = inds[None]
mask = (linds < rinds) | (linds > rinds + window_size)
return mask * -1e9
class RecurrentCache:
def __init__(self):
self._cache = (None, None)
def __getitem__(self, idx):
return self._cache[idx]
def update(self, conv_state, recurrent_state):
self._cache = (conv_state, recurrent_state)
def state(self):
return self._cache
class WindowKVCache:
def __init__(self, window_size):
self.keys = None
self.values = None
self.offset = 0
self.window_size = window_size
def update_and_fetch(self, keys, values):
# TODO consider using rotating buffer here
# especially for very long generations
def _update(x, v):
t = x.shape[2] - self.window_size
if t > 0:
x = x[..., t:, :]
return mx.concatenate([x, v], axis=2)
self.offset += keys.shape[2]
if self.keys is None:
self.keys = keys
self.values = values
else:
self.keys = _update(self.keys, keys)
self.values = _update(self.values, values)
return self.keys, self.values
def state(self):
return self.keys, self.values
class RMSNorm(nn.Module):
def __init__(self, dims: int, eps: float = 1e-5):
super().__init__()
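For orientation, a minimal sketch (assuming the interfaces used elsewhere in this diff) of the shared classes that take over from the removed RecurrentCache and WindowKVCache: MambaCache exposes its two state slots through item access, and RotatingKVCache keeps WindowKVCache's update_and_fetch interface while bounding the stored keys and values by max_size.

import mlx.core as mx
from mlx_lm.models.cache import MambaCache, RotatingKVCache

# MambaCache: two state slots, read and written by index
# (see the RecurrentBlock change further down).
c = MambaCache()
c[0] = mx.zeros((1, 3, 8))  # e.g. a conv state; shapes here are illustrative
c[1] = mx.zeros((1, 8))     # e.g. a recurrent state

# RotatingKVCache: same update_and_fetch call as the removed WindowKVCache,
# but the stored keys/values stay bounded by max_size.
kv = RotatingKVCache(max_size=4)
for _ in range(6):
    k = mx.random.normal((1, 2, 1, 16))  # (batch, kv heads, 1 new token, head dim)
    v = mx.random.normal((1, 2, 1, 16))
    keys, values = kv.update_and_fetch(k, v)

print(kv.offset)      # 6: counts every token seen so far
print(keys.shape[2])  # stays at (or near) max_size instead of growing with offset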
@@ -136,31 +83,22 @@ class Conv1d(nn.Module):
kernel_size: int,
):
super().__init__()
self.weight = mx.zeros((kernel_size, channels))
self.weight = mx.zeros((channels, kernel_size, 1))
self.bias = mx.zeros((channels,))
def __call__(self, x, cache=None):
w = self.weight.T[..., None]
kw, groups = self.weight.shape
if cache is not None:
l = []
# Pad the cache if needed
if cache.shape[1] < kw - 1:
l.append(
mx.zeros(
(x.shape[0], kw - 1 - cache.shape[1], groups), dtype=x.dtype
)
)
l.extend([cache, x])
x = mx.concatenate(l, axis=1)
y = (x * w.swapaxes(0, 2)).sum(axis=1, keepdims=True)
else:
y = mx.conv_general(x, w, padding=([kw - 1], [0]), groups=groups)
B, L, C = x.shape
groups, K, _ = self.weight.shape
# The cache is always kw - 1
cache = x[:, max(x.shape[1] - kw + 1, 0) :, :]
if cache is not None:
x = mx.concatenate([cache, x], axis=1)
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=groups)
y = y + self.bias
return y, cache
return y, x[:, -K + 1 :, :]
class RGLRU(nn.Module):
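The rewritten Conv1d is easier to check as a standalone function. The sketch below re-implements the same cached depthwise convolution pattern (a sketch, not the library class; the bias is omitted) and verifies the property the cache exists for: running the sequence in chunks while threading the last K - 1 inputs through as the cache matches running it in one shot.

import mlx.core as mx

def depthwise_conv1d(x, weight, cache=None):
    # x: (B, L, C); weight: (C, K, 1), one filter per channel (groups == C)
    C, K, _ = weight.shape
    if cache is not None:
        x = mx.concatenate([cache, x], axis=1)
    else:
        x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
    y = mx.conv_general(x, weight, groups=C)
    # The returned cache is always the last K - 1 inputs.
    return y, x[:, -K + 1 :, :]

B, L, C, K = 1, 8, 4, 4
x = mx.random.normal((B, L, C))
w = mx.random.normal((C, K, 1))

y_full, _ = depthwise_conv1d(x, w)                  # whole sequence at once

y1, state = depthwise_conv1d(x[:, :5], w)           # first 5 tokens
y2, _ = depthwise_conv1d(x[:, 5:], w, cache=state)  # remaining 3, reusing the cache
y_chunked = mx.concatenate([y1, y2], axis=1)

print(mx.allclose(y_full, y_chunked))  # True (up to floating point tolerance)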
@@ -269,19 +207,9 @@ class RecurrentBlock(nn.Module):
# x branch.
x = self.linear_x(x)
if cache is None:
conv_state, recurrent_state = (None, None)
else:
conv_state, recurrent_state = cache[0], cache[1]
x, conv_state = self.conv_1d(
x=x,
cache=conv_state,
)
x, recurrent_state = self.rg_lru(
x=x,
cache=recurrent_state,
)
if cache is not None:
cache.update(conv_state, recurrent_state)
cache = [None, None]
x, cache[0] = self.conv_1d(x=x, cache=cache[0])
x, cache[1] = self.rg_lru(x=x, cache=cache[1])
x = x * y
x = self.linear_out(x)
@@ -467,12 +395,14 @@ class Griffin(nn.Module):
if self.scale_by_sqrt_dim:
x = x * math.sqrt(x.shape[-1])
mask = None
if x.shape[1] > 1:
mask = create_window_causal_mask(
x.shape[1], self.config.attention_window_size
)
mask = mask.astype(x.dtype)
if cache is None:
cache = [None] * len(self.layers)
for i, block in enumerate(self.layers):
if block.temporal_block_type != "recurrent":
mask_cache = [cache[i]]
mask = create_attention_mask(x, mask_cache)
for i, block in enumerate(self.layers):
x = block(x, mask=mask, cache=cache[i])
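For the attention blocks, the shared helper now builds the mask from the layer's cache. A small hedged sketch of its behavior, assuming the call signature used above: it returns an additive causal mask for a multi-token prompt (offset by whatever the cache already holds) and None for single-token decoding.

import mlx.core as mx
from mlx_lm.models.base import create_attention_mask
from mlx_lm.models.cache import RotatingKVCache

x = mx.zeros((1, 5, 32))          # (batch, new tokens, hidden size)
kv = RotatingKVCache(max_size=4)  # stands in for an attention layer's cache

mask = create_attention_mask(x, [kv])
print(mask.shape)  # mask over the 5 new tokens plus any cached offset

# Single-token decoding needs no mask, so None comes back.
print(create_attention_mask(mx.zeros((1, 1, 32)), [kv]))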
@@ -485,6 +415,7 @@ class Model(nn.Module):
def __init__(self, config):
self.args = config
self.model = Griffin(config)
self.model_type = config.model_type
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
def __call__(self, tokens: mx.array, cache=None) -> mx.array:
@@ -508,10 +439,9 @@ class Model(nn.Module):
return self.model.layers
def sanitize(self, weights):
# Remove unused precomputed rotary freqs
for k, v in weights.items():
if "conv_1d.weight" in k and v.ndim == 3:
weights[k] = v.squeeze(1).T
weights[k] = v.moveaxis(2, 1)
if "lm_head.weight" not in weights:
self.pop("lm_head")
return weights
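The sanitize change is only a re-layout of the checkpoint's conv_1d weights. A shape-only sketch (the channel and kernel sizes are made up) of the old and new mappings:

import mlx.core as mx

w_ckpt = mx.zeros((256, 1, 4))      # checkpoint layout: (channels, 1, kernel_size)

print(w_ckpt.moveaxis(2, 1).shape)  # (256, 4, 1): the new (channels, kernel_size, 1) layout
print(w_ckpt.squeeze(1).T.shape)    # (4, 256): the old (kernel_size, channels) layout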
@@ -520,7 +450,7 @@ class Model(nn.Module):
cache = []
for layer in self.layers:
if layer.temporal_block_type == "recurrent":
cache.append(RecurrentCache())
cache.append(MambaCache())
else:
cache.append(WindowKVCache(self.args.attention_window_size))
cache.append(RotatingKVCache(max_size=self.args.attention_window_size))
return cache
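With that, the model's cache builder hands back one shared cache object per layer, its type following each block's temporal_block_type. A hypothetical check, assuming the surrounding method is the model's make_cache and with the checkpoint path as a placeholder:

from mlx_lm import load
from mlx_lm.models.cache import MambaCache, RotatingKVCache

# Placeholder path; any recurrent-gemma style checkpoint converted for mlx-lm.
model, _ = load("path/to/recurrentgemma-checkpoint")

cache = model.make_cache()
# Every entry is now one of the shared classes, per block type.
print(all(isinstance(c, (MambaCache, RotatingKVCache)) for c in cache))  # True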