Define meta_state on all Cache implementations

This commit is contained in:
Angelos Katharopoulos 2024-10-07 15:46:48 -07:00
parent f6ff4f28b4
commit 7a3d0dd459

View File

@ -32,23 +32,21 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
return [KVCache() for _ in range(num_layers)]
def save_prompt_cache(
    file_name: str, cache: List[Any], metadata: Optional[Dict[str, str]] = None
):
    """
    Save a pre-computed prompt cache to a file.

    The ``state`` of every cache entry is flattened and written as the
    tensor payload. The class name and ``meta_state`` of every entry,
    together with the user supplied ``metadata``, are flattened and written
    as the safetensors metadata.

    Args:
        file_name (str): The ``.safetensors`` file name.
        cache (List[Any]): The model state.
        metadata (Optional[Dict[str, str]]): Optional metadata to save along
          with model state. ``None`` (the default) means no extra metadata.
    """
    # Use None as the default instead of a literal ``{}`` so that no dict
    # object is shared between calls; an empty dict flattens to nothing.
    if metadata is None:
        metadata = {}

    cache_data = dict(tree_flatten([c.state for c in cache]))

    cache_classes = [type(c).__name__ for c in cache]
    cache_info = [c.meta_state for c in cache]
    cache_metadata = dict(tree_flatten([cache_classes, cache_info, metadata]))

    mx.save_safetensors(file_name, cache_data, cache_metadata)
@ -73,14 +71,33 @@ def load_prompt_cache(file_name, return_metadata=False):
cache = [globals()[c]() for c in classes]
for c, state, meta_state in zip(cache, arrays, info):
c.state = state
if hasattr(c, "meta_state"):
c.meta_state = meta_state
c.meta_state = meta_state
if return_metadata:
return cache, metadata
return cache
class KVCache:
class _BaseCache:
@property
def state(self):
return []
@state.setter
def state(self, v):
if v is not None and v:
raise ValueError("This cache has no state but a state was set.")
@property
def meta_state(self):
return {}
@state.setter
def meta_state(self, v):
if v is not None and v:
raise ValueError("This cache has no meta_state but a meta_state was set.")
class KVCache(_BaseCache):
def __init__(self):
self.keys = None
@ -128,7 +145,7 @@ class KVCache:
self.offset = self.keys.shape[2]
class RotatingKVCache:
class RotatingKVCache(_BaseCache):
def __init__(self, max_size=None, keep=0, step=256):
self.keep = keep
@ -259,7 +276,7 @@ class RotatingKVCache:
)
class MambaCache:
class MambaCache(_BaseCache):
def __init__(self):
self.cache = [None, None]