reorg + fixes to caching, unify prompt caching across types and use cases for e.g. caching during a chat

This commit is contained in:
Awni Hannun
2024-10-05 14:49:39 -07:00
parent ed060a7c5c
commit 782f5a71b7
40 changed files with 824 additions and 691 deletions

View File

@@ -2,153 +2,9 @@
import inspect
from dataclasses import dataclass
from typing import Any, List, Optional
from typing import Any, Optional
import mlx.core as mx
import mlx.nn as nn
class KVCache:
def __init__(self, head_dim, n_kv_heads):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keys = None
self.values = None
self.offset = 0
self.step = 256
def update_and_fetch(self, keys, values):
prev = self.offset
if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
B = keys.shape[0]
n_steps = (self.step + keys.shape[2] - 1) // self.step
k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
if prev % self.step != 0:
self.keys = self.keys[..., :prev, :]
self.values = self.values[..., :prev, :]
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self.offset += keys.shape[2]
self.keys[..., prev : self.offset, :] = keys
self.values[..., prev : self.offset, :] = values
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
@property
def state(self):
return self.keys, self.values
class RotatingKVCache:
def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
self.n_kv_heads = n_kv_heads
if isinstance(head_dim, int):
self.k_head_dim = self.v_head_dim = head_dim
elif isinstance(head_dim, tuple) and len(head_dim) == 2:
self.k_head_dim, self.v_head_dim = head_dim
else:
raise ValueError("head_dim must be an int or a tuple of two ints")
self.keep = keep
self.keys = None
self.values = None
self.offset = 0
self.max_size = max_size
self.step = step
self._idx = 0
def _trim(self, trim_size, v, append=None):
to_cat = []
if trim_size > 0:
to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
else:
to_cat = [v]
if append is not None:
to_cat.append(append)
return mx.concatenate(to_cat, axis=2)
def _update_concat(self, keys, values):
if self.keys is None:
self.keys = keys
self.values = values
else:
if self._idx < self.keys.shape[2]:
self.keys = self.keys[..., : self._idx, :]
self.values = self.values[..., : self._idx, :]
# The largest size is self.max_size + S - 1 to ensure
# every token gets at least self.max_size context
trim_size = self._idx - self.max_size + 1
self.keys = self._trim(trim_size, self.keys, keys)
self.values = self._trim(trim_size, self.values, values)
self.offset += keys.shape[2]
self._idx = self.keys.shape[2]
return self.keys, self.values
def _update_in_place(self, keys, values):
# May not have hit the max size yet, so potentially
# keep growing the cache
B, _, S = keys.shape[:3]
prev = self.offset
if self.keys is None or (
prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
):
new_size = min(self.step, self.max_size - prev)
k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
new_k = mx.zeros(k_shape, keys.dtype)
new_v = mx.zeros(v_shape, values.dtype)
if self.keys is not None:
self.keys = mx.concatenate([self.keys, new_k], axis=2)
self.values = mx.concatenate([self.values, new_v], axis=2)
else:
self.keys, self.values = new_k, new_v
self._idx = prev
# Trim if needed
trim_size = self.keys.shape[2] - self.max_size
if trim_size > 0:
self.keys = self._trim(trim_size, self.keys)
self.values = self._trim(trim_size, self.values)
self._idx = self.max_size
# Rotate
if self._idx == self.max_size:
self._idx = self.keep
# Assign
self.keys[..., self._idx : self._idx + S, :] = keys
self.values[..., self._idx : self._idx + S, :] = values
self.offset += S
self._idx += S
# If the buffer is not full, slice off the end
if self.offset < self.max_size:
return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
return self.keys, self.values
def update_and_fetch(self, keys, values):
S = keys.shape[2]
if S == 1 or (self.keys is not None and S < (self.keys.shape[2] - self._idx)):
return self._update_in_place(keys, values)
return self._update_concat(keys, values)
@property
def state(self):
return self.keys, self.values
@dataclass
@@ -164,25 +20,30 @@ class BaseModelArgs:
)
def create_additive_causal_mask(N: int, offset: int = 0):
def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None):
rinds = mx.arange(offset + N)
linds = mx.arange(offset, offset + N) if offset else rinds
mask = linds[:, None] < rinds[None]
linds = linds[:, None]
rinds = rinds[None]
mask = linds < rinds
if window_size is not None:
mask = mask | (linds > rinds + window_size)
return mask * -1e9
def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
T = h.shape[1]
if T > 1:
window_size = None
offset = 0
if cache is not None and cache[0] is not None:
c = cache[0]
if isinstance(c, RotatingKVCache):
if hasattr(c, "max_size"):
offset = min(c.max_size - 1, c.offset)
window_size = c.max_size
else:
offset = c.offset
else:
offset = 0
mask = create_additive_causal_mask(T, offset)
mask = create_causal_mask(T, offset, window_size=window_size)
mask = mask.astype(h.dtype)
else:
mask = None