# Copyright © 2023-2024 Apple Inc.

from typing import Any, Dict, List, Optional

import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten, tree_unflatten


def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
    """
    Construct the per-layer cache objects for a model.

    If the model defines a ``make_cache`` method, its result is returned
    as-is. Otherwise one cache is created per entry in ``model.layers``.

    Args:
        model (nn.Module): The language model.
        max_kv_size (Optional[int]): If provided, each layer gets a
          ``RotatingKVCache`` with this maximum size and ``keep=4``;
          otherwise each layer gets a plain ``KVCache``. Default: ``None``.

    Returns:
        List[Any]: One cache object per layer.
    """
    if hasattr(model, "make_cache"):
        return model.make_cache()

    num_layers = len(model.layers)
    if max_kv_size is not None:
        return [
            RotatingKVCache(max_size=max_kv_size, keep=4) for _ in range(num_layers)
        ]
    else:
        return [KVCache() for _ in range(num_layers)]


def save_prompt_cache(
    file_name: str, cache: List[Any], metadata: Optional[Dict[str, str]] = None
):
    """
    Save a pre-computed prompt cache to a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        cache (List[Any]): The cache objects to save. Each must expose a
          ``state`` property returning a ``(arrays, info)`` pair.
        metadata (Optional[Dict[str, str]]): Optional metadata to store
          alongside the cache. It is only written when non-empty.
    """
    # Each cache's state is (array data, string info); split them so the
    # arrays go in the tensor section and the info goes in the metadata.
    cache_data, cache_info = zip(*(c.state for c in cache))
    cache_data = dict(tree_flatten(cache_data))
    cache_classes = [type(c).__name__ for c in cache]
    cache_metadata = [cache_classes, cache_info]
    if metadata:
        cache_metadata.append(metadata)

    cache_metadata = dict(tree_flatten(cache_metadata))
    mx.save_safetensors(file_name, cache_data, cache_metadata)


def load_prompt_cache(file_name, return_metadata=False):
    """
    Load a prompt cache from a file.

    Args:
        file_name (str): The ``.safetensors`` file name.
        return_metadata (bool): Whether or not to return metadata.
          Default: ``False``.

    Returns:
        List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
        the metadata if requested. The metadata is an empty dictionary if
        the file was saved without any.
    """
    arrays, cache_metadata = mx.load(file_name, return_metadata=True)
    arrays = tree_unflatten(list(arrays.items()))
    cache_metadata = tree_unflatten(list(cache_metadata.items()))
    classes, info = cache_metadata[:2]
    # NOTE(security): class names come from the file's metadata and are
    # resolved against this module's globals. Only load prompt caches from
    # trusted sources.
    cache = [globals()[c]() for c in classes]
    for c, *state in zip(cache, arrays, info):
        c.state = state
    if return_metadata:
        # ``save_prompt_cache`` only writes user metadata when it is
        # non-empty, so the third entry may be absent.
        user_metadata = cache_metadata[2] if len(cache_metadata) > 2 else {}
        return cache, user_metadata
    return cache


class KVCache:
    """
    A key/value cache that grows along axis 2 in chunks of ``step``.

    ``offset`` tracks how many positions along axis 2 are filled.
    """

    def __init__(self):
        self.keys = None
        self.values = None
        self.offset = 0
        self.step = 256

    def update_and_fetch(self, keys, values):
        """
        Append ``keys`` and ``values`` to the cache and return everything
        cached so far.

        Args:
            keys: Array of new keys; its shape is unpacked as
              ``(B, n_kv_heads, S, k_head_dim)``.
            values: Array of new values; ``values.shape[3]`` is used as the
              value head dimension.

        Returns:
            The cached keys and values sliced to the first ``offset``
            positions along axis 2.
        """
        prev = self.offset
        if self.keys is None or (prev + keys.shape[2]) > self.keys.shape[2]:
            B, n_kv_heads, _, k_head_dim = keys.shape
            v_head_dim = values.shape[3]
            # Allocate enough whole steps to hold the incoming entries.
            n_steps = (self.step + keys.shape[2] - 1) // self.step
            k_shape = (B, n_kv_heads, n_steps * self.step, k_head_dim)
            v_shape = (B, n_kv_heads, n_steps * self.step, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                # Drop the unused tail so the new buffer starts right
                # after the last filled position.
                if prev % self.step != 0:
                    self.keys = self.keys[..., :prev, :]
                    self.values = self.values[..., :prev, :]
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v

        self.offset += keys.shape[2]
        self.keys[..., prev : self.offset, :] = keys
        self.values[..., prev : self.offset, :] = values
        return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]

    @property
    def state(self):
        """
        The ``((keys, values), info)`` pair used for serialization, with
        the arrays sliced to the filled portion. ``info`` is always ``""``.
        """
        if self.offset == self.keys.shape[2]:
            return (self.keys, self.values), ""
        else:
            return (
                self.keys[..., : self.offset, :],
                self.values[..., : self.offset, :],
            ), ""

    @state.setter
    def state(self, v):
        # Only the arrays are restored; the offset is taken from their
        # size along axis 2.
        self.keys, self.values = v[0]
        self.offset = self.keys.shape[2]


class RotatingKVCache:
    """
    A key/value cache bounded by ``max_size`` along axis 2.

    Single-position updates are written in place and wrap around once the
    buffer is full, never overwriting the first ``keep`` positions.
    Multi-position updates are concatenated after restoring temporal order.

    Args:
        max_size: Maximum number of positions to retain along axis 2.
        keep: Number of leading positions that are never evicted.
        step: Number of positions to grow by while below ``max_size``.
    """

    def __init__(self, max_size=None, keep=0, step=256):
        self.keep = keep
        self.keys = None
        self.values = None
        self.offset = 0
        self.max_size = max_size
        self.step = step
        self._idx = 0

    def _trim(self, trim_size, v, append=None):
        """
        Remove ``trim_size`` positions that follow the first ``keep``
        positions of ``v`` (when ``trim_size > 0``) and optionally
        concatenate ``append`` at the end along axis 2.
        """
        if trim_size > 0:
            to_cat = [v[..., : self.keep, :], v[..., trim_size + self.keep :, :]]
        else:
            to_cat = [v]
        if append is not None:
            to_cat.append(append)
        return mx.concatenate(to_cat, axis=2)

    def _temporal_order(self, v):
        """
        Rearrange the cache into temporal order, slicing off the end if unused.
        """
        if self._idx == v.shape[2]:
            return v
        elif self._idx < self.offset:
            # The buffer has wrapped: kept prefix, then the older entries
            # after the write index, then the newer ones before it.
            return mx.concatenate(
                [
                    v[..., : self.keep, :],
                    v[..., self._idx :, :],
                    v[..., self.keep : self._idx, :],
                ],
                axis=2,
            )
        else:
            return v[..., : self._idx, :]

    def _update_concat(self, keys, values):
        """
        Append ``keys`` and ``values`` by concatenation and return the full
        cached keys and values.
        """
        if self.keys is None:
            self.keys = keys
            self.values = values
        else:
            # Put the keys/values in temporal order to
            # preserve context
            self.keys = self._temporal_order(self.keys)
            self.values = self._temporal_order(self.values)

            # The largest size is self.max_size + S - 1 to ensure
            # every token gets at least self.max_size context
            trim_size = self._idx - self.max_size + 1
            self.keys = self._trim(trim_size, self.keys, keys)
            self.values = self._trim(trim_size, self.values, values)
        self.offset += keys.shape[2]
        self._idx = self.keys.shape[2]
        return self.keys, self.values

    def _update_in_place(self, keys, values):
        """
        Write ``keys`` and ``values`` into the buffer at the current write
        index, growing or rotating the buffer as needed, and return the
        cached keys and values.
        """
        # May not have hit the max size yet, so potentially
        # keep growing the cache
        B, n_kv_heads, S, k_head_dim = keys.shape
        prev = self.offset
        if self.keys is None or (
            prev >= self.keys.shape[2] and self.keys.shape[2] < self.max_size
        ):
            v_head_dim = values.shape[3]
            new_size = min(self.step, self.max_size - prev)
            k_shape = (B, n_kv_heads, new_size, k_head_dim)
            v_shape = (B, n_kv_heads, new_size, v_head_dim)
            new_k = mx.zeros(k_shape, keys.dtype)
            new_v = mx.zeros(v_shape, values.dtype)
            if self.keys is not None:
                self.keys = mx.concatenate([self.keys, new_k], axis=2)
                self.values = mx.concatenate([self.values, new_v], axis=2)
            else:
                self.keys, self.values = new_k, new_v
            self._idx = prev

        # Trim if needed
        trim_size = self.keys.shape[2] - self.max_size
        if trim_size > 0:
            self.keys = self._trim(trim_size, self.keys)
            self.values = self._trim(trim_size, self.values)
            self._idx = self.max_size

        # Rotate
        if self._idx == self.max_size:
            self._idx = self.keep

        # Assign
        self.keys[..., self._idx : self._idx + S, :] = keys
        self.values[..., self._idx : self._idx + S, :] = values
        self.offset += S
        self._idx += S

        # If the buffer is not full, slice off the end
        if self.offset < self.max_size:
            return self.keys[..., : self.offset, :], self.values[..., : self.offset, :]
        return self.keys, self.values

    def update_and_fetch(self, keys, values):
        """
        Add ``keys`` and ``values`` to the cache and return the cached keys
        and values.

        Updates with a single position along axis 2 are written in place;
        longer updates are concatenated.
        """
        if keys.shape[2] == 1:
            return self._update_in_place(keys, values)
        return self._update_concat(keys, values)

    @property
    def state(self):
        """
        The ``((keys, values), info)`` pair used for serialization.

        ``info`` is a tuple of strings holding ``keep``, ``max_size``,
        ``step``, ``offset`` and the internal write index.
        """
        if self.offset < self.keys.shape[2]:
            # Slice along axis 2, matching every other slice in this class;
            # a bare ``[..., : self.offset]`` would truncate the last axis.
            kv_state = (
                self.keys[..., : self.offset, :],
                self.values[..., : self.offset, :],
            )
        else:
            kv_state = (self.keys, self.values)
        extra_state = tuple(
            map(str, (self.keep, self.max_size, self.step, self.offset, self._idx))
        )
        return kv_state, extra_state

    @state.setter
    def state(self, v):
        self.keys, self.values = v[0]
        self.keep, self.max_size, self.step, self.offset, self._idx = map(
            int,
            v[1],
        )


class MambaCache:
    """
    A two-slot cache; the slots are read and written by index.
    """

    def __init__(self):
        self.cache = [None, None]

    def __setitem__(self, idx, value):
        self.cache[idx] = value

    def __getitem__(self, idx):
        return self.cache[idx]

    @property
    def state(self):
        """
        The ``(cache, info)`` pair used for serialization. ``info`` is
        always ``""``.
        """
        return self.cache, ""

    @state.setter
    def state(self, v):
        self.cache = v[0]