# Copyright © 2023-2024 Apple Inc.

import inspect
from dataclasses import dataclass
from typing import Any, List, Optional

import mlx.core as mx
import mlx.nn as nn


class KVCache:
    """A key/value cache that grows with the sequence during decoding."""

    def __init__(self, head_dim, n_kv_heads):
        self.n_kv_heads = n_kv_heads
        if isinstance(head_dim, int):
            self.k_head_dim = self.v_head_dim = head_dim
        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
            self.k_head_dim, self.v_head_dim = head_dim
        else:
            raise ValueError("head_dim must be an int or a tuple of two ints")
        self.keys = None
        self.values = None
        self.offset = 0
        # Grow the backing buffers in fixed-size increments to amortize
        # the cost of reallocation.
        self.step = 256

    def update_and_fetch(self, keys, values):
        prev = self.offset
        # Reallocate if the incoming keys do not fit in the current buffer.
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            B = keys.shape[0]
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (B, self.n_kv_heads, n_steps * self.step, self.k_head_dim)
            v_shape = (B, self.n_kv_heads, n_steps * self.step, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                # Drop any unused tail so the filled region stays contiguous.
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
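    # E.g. (hypothetical numbers): with step=256, writing 300 tokens into an
    # empty cache allocates ceil(300 / 256) = 2 steps, i.e. 512 slots, of
    # which the first 300 are filled.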

    def state(self):
        return self.keys, self.values
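
# A minimal usage sketch (illustrative; the sizes below are assumptions, not
# taken from any particular model). Each attention layer keeps one cache and
# calls update_and_fetch once per forward pass:
#
#   cache = KVCache(head_dim=64, n_kv_heads=8)
#   k = mx.random.normal((1, 8, 5, 64))  # (B, n_kv_heads, S, head_dim)
#   v = mx.random.normal((1, 8, 5, 64))
#   keys, values = cache.update_and_fetch(k, v)
#   # keys.shape[2] == cache.offset == 5, while the backing buffer already
#   # holds one step of 256 slots.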


class RotatingKVCache:
    """A bounded key/value cache that overwrites its oldest entries.

    The first `keep` positions are never evicted (e.g. to pin an initial
    prompt), and the buffer holds at most `max_size` entries.
    """

    def __init__(self, head_dim, n_kv_heads, max_size, keep=0, step=256):
        self.n_kv_heads = n_kv_heads
        if isinstance(head_dim, int):
            self.k_head_dim = self.v_head_dim = head_dim
        elif isinstance(head_dim, tuple) and len(head_dim) == 2:
            self.k_head_dim, self.v_head_dim = head_dim
        else:
            raise ValueError("head_dim must be an int or a tuple of two ints")
        self.keep = keep
        self.keys = None
        self.values = None
        self.offset = 0
        self.max_size = max_size
        self.step = step
        # Write position for the next token in generation mode.
        self._idx = 0

    def _trim(self, trim_size, v, append=None):
        # Drop `trim_size` entries right after the first `keep` positions,
        # optionally appending new entries at the end.
        if trim_size > 0:
            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
        else:
            to_cat = [v]
        if append is not None:
            to_cat.append(append)
        return mx.concatenate(to_cat, axis=2)
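    # For example (hypothetical values): with keep=2 and trim_size=3, a
    # sequence axis [0 1 2 3 4 5 6 7] becomes [0 1 5 6 7]: the two kept
    # positions survive and the three oldest evictable entries are dropped.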

    def update_and_fetch(self, keys, values):
        prev = self.offset
        B, _, S = keys.shape[:3]

        # Prefill mode: several tokens arrive at once.
        if S > 1:
            if self.keys is None:
                self.keys = keys
                self.values = values
            else:
                # The largest size is self.max_size + S - 1 to ensure
                # every token gets at least self.max_size context.
                trim_size = self.keys.shape[2] - self.max_size + 1
                self.keys = self._trim(trim_size, self.keys, keys)
                self.values = self._trim(trim_size, self.values, values)
            self.offset += S
            self._idx = self.keys.shape[2]
            return self.keys, self.values
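        # E.g. (hypothetical numbers): with max_size=512 and 600 entries
        # already cached, a 5-token prefill trims 600 - 512 + 1 = 89 old
        # entries, leaving 516 = max_size + S - 1 positions.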
        # Generation mode: one token at a time. The cache may not have hit
        # max_size yet, so potentially keep growing it.
        if self.keys is None or (
            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
        ):
            new_size = min(self.step, self.max_size - prev)
            k_shape = (B, self.n_kv_heads, new_size, self.k_head_dim)
            v_shape = (B, self.n_kv_heads, new_size, self.v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
            self._idx = prev

        # Trim if the buffer has grown past max_size.
        trim_size = self.keys.shape[2] - self.max_size
        if trim_size > 0:
            self.keys = self._trim(trim_size, self.keys)
            self.values = self._trim(trim_size, self.values)
            self._idx = self.max_size

        # Rotate: wrap the write position back past the kept prefix.
        if self._idx == self.max_size:
            self._idx = self.keep

        # Assign the new token at the current write position.
        self.keys[..., self._idx : self._idx + 1, :] = keys
        self.values[..., self._idx : self._idx + 1, :] = values
        self.offset += 1
        self._idx += 1

        # If the buffer is not full, slice off the unused tail.
        if self.offset < self.max_size:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        return self.keys, self.values

    def state(self):
        return self.keys, self.values
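
# A minimal usage sketch (illustrative values): a rotating cache bounds
# memory during long generations while pinning the first `keep` tokens.
#
#   cache = RotatingKVCache(head_dim=64, n_kv_heads=8, max_size=512, keep=4)
#   k = mx.random.normal((1, 8, 1, 64))  # one generated token at a time
#   v = mx.random.normal((1, 8, 1, 64))
#   keys, values = cache.update_and_fetch(k, v)
#   # Once cache.offset reaches max_size, each new token overwrites the
#   # oldest entry after position `keep`, so keys.shape[2] stays at 512.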


@dataclass
class BaseModelArgs:
    @classmethod
    def from_dict(cls, params):
        # Keep only the keys that match this dataclass's constructor
        # parameters, so unknown entries in a config dict are ignored.
        return cls(
            **{
                k: v
                for k, v in params.items()
                if k in inspect.signature(cls).parameters
            }
        )
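
# For instance (hypothetical config), a subclass only picks up the fields it
# declares:
#
#   @dataclass
#   class ModelArgs(BaseModelArgs):
#       hidden_size: int
#       num_layers: int = 12
#
#   args = ModelArgs.from_dict({"hidden_size": 768, "unused_key": 1})
#   # args.hidden_size == 768; "unused_key" is silently dropped.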


def create_additive_causal_mask(N: int, offset: int = 0):
    rinds = mx.arange(offset + N)
    linds = mx.arange(offset, offset + N) if offset else rinds
    mask = linds[:, None] < rinds[None]
    # Disallowed positions get a large negative value, which softmax maps
    # to (approximately) zero attention weight.
    return mask * -1e9
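
# For N=3 and offset=0 the mask is, schematically:
#
#   [[0, -1e9, -1e9],
#    [0,    0, -1e9],
#    [0,    0,    0]]
#
# With offset > 0, the rows are queries at absolute positions
# offset..offset+N-1, so every query also attends to all cached tokens.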


def create_attention_mask(h: mx.array, cache: Optional[Any] = None):
    T = h.shape[1]
    if T > 1:
        if cache is not None and cache[0] is not None:
            c = cache[0]
            if isinstance(c, RotatingKVCache):
                # A rotating cache holds at most max_size entries, so the
                # usable context is capped accordingly.
                offset = min(c.max_size - 1, c.offset)
            else:
                offset = c.offset
        else:
            offset = 0
        mask = create_additive_causal_mask(T, offset)
        mask = mask.astype(h.dtype)
    else:
        # A single query may attend to everything cached; no mask needed.
        mask = None
    return mask
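

# A minimal end-to-end sketch (illustrative only; the sizes are arbitrary):
if __name__ == "__main__":
    h = mx.zeros((1, 10, 512))  # (batch, seq_len, model_dim)
    cache = [KVCache(head_dim=64, n_kv_heads=8)]

    # With 10 queries and an empty cache, the mask is a (10, 10) causal mask.
    mask = create_attention_mask(h, cache)
    assert mask.shape[0] == 10 and mask.shape[1] == 10

    # Cache the keys/values for these 10 positions.
    k = mx.zeros((1, 8, 10, 64))
    v = mx.zeros((1, 8, 10, 64))
    keys, values = cache[0].update_and_fetch(k, v)
    assert cache[0].offset == 10 and keys.shape[2] == 10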