diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 49e3a198..8d5101d2 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -32,23 +32,21 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
     return [KVCache() for _ in range(num_layers)]
 
 
-def save_prompt_cache(
-    file_name: str, cache: List[Any], metadata: Optional[Dict[str, str]] = None
-):
+def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
     """
     Save a pre-computed prompt cache to a file.
 
     Args:
         file_name (str): The ``.safetensors`` file name.
         cache (List[Any]): The model state.
-        metadata (Optional[Dict[str, str]]): Optional metadata to save along
-            with model state..
+        metadata (Dict[str, str]): Optional metadata to save along with model
+            state.
     """
     cache_data = [c.state for c in cache]
-    cache_info = [c.meta_state if hasattr(c, "meta_state") else "" for c in cache]
+    cache_info = [c.meta_state for c in cache]
     cache_data = dict(tree_flatten(cache_data))
     cache_classes = [type(c).__name__ for c in cache]
-    cache_metadata = [cache_classes, cache_info, metadata or ""]
+    cache_metadata = [cache_classes, cache_info, metadata]
     cache_metadata = dict(tree_flatten(cache_metadata))
     mx.save_safetensors(file_name, cache_data, cache_metadata)
 
@@ -73,14 +71,33 @@ def load_prompt_cache(file_name, return_metadata=False):
     cache = [globals()[c]() for c in classes]
     for c, state, meta_state in zip(cache, arrays, info):
         c.state = state
-        if hasattr(c, "meta_state"):
-            c.meta_state = meta_state
+        c.meta_state = meta_state
     if return_metadata:
         return cache, metadata
     return cache
 
 
-class KVCache:
+class _BaseCache:
+    @property
+    def state(self):
+        return []
+
+    @state.setter
+    def state(self, v):
+        if v is not None and v:
+            raise ValueError("This cache has no state but a state was set.")
+
+    @property
+    def meta_state(self):
+        return ""
+
+    @meta_state.setter
+    def meta_state(self, v):
+        if v is not None and v:
+            raise ValueError("This cache has no meta_state but a meta_state was set.")
+
+
+class KVCache(_BaseCache):
 
     def __init__(self):
         self.keys = None
@@ -128,7 +145,7 @@ class KVCache:
         self.offset = self.keys.shape[2]
 
 
-class RotatingKVCache:
+class RotatingKVCache(_BaseCache):
 
     def __init__(self, max_size=None, keep=0, step=256):
         self.keep = keep
@@ -259,7 +276,7 @@ class RotatingKVCache:
         )
 
 
-class MambaCache:
+class MambaCache(_BaseCache):
 
     def __init__(self):
         self.cache = [None, None]
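
A minimal sketch of the save/load round trip this patch enables. The model name, prompt, and file name are placeholders; only make_prompt_cache, save_prompt_cache, and load_prompt_cache come from the patched module, and the cache is assumed to be filled by a single forward pass over the prompt.

    import mlx.core as mx
    from mlx_lm import load
    from mlx_lm.models.cache import (
        load_prompt_cache,
        make_prompt_cache,
        save_prompt_cache,
    )

    # Example model; any mlx_lm-compatible checkpoint would do.
    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

    # Build a fresh per-layer cache and populate it by running the model
    # over a prompt prefix.
    prompt_cache = make_prompt_cache(model)
    tokens = mx.array(tokenizer.encode("You are a helpful assistant."))[None]
    mx.eval(model(tokens, cache=prompt_cache))

    # Metadata is now a plain Dict[str, str]; every cache type exposes
    # state/meta_state via _BaseCache, so no hasattr checks are needed.
    save_prompt_cache("prefix.safetensors", prompt_cache, {"note": "system prompt"})

    # Reload the per-layer caches; return_metadata=True also returns the
    # metadata dict saved above.
    prompt_cache, metadata = load_prompt_cache("prefix.safetensors", return_metadata=True)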