mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Define meta_state on all Cache implementations
This commit is contained in:
parent
f6ff4f28b4
commit
7a3d0dd459
@ -32,23 +32,21 @@ def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> Li
|
||||
return [KVCache() for _ in range(num_layers)]
|
||||
|
||||
|
||||
def save_prompt_cache(
|
||||
file_name: str, cache: List[Any], metadata: Optional[Dict[str, str]] = None
|
||||
):
|
||||
def save_prompt_cache(file_name: str, cache: List[Any], metadata: Dict[str, str] = {}):
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
cache (List[Any]): The model state.
|
||||
metadata (Optional[Dict[str, str]]): Optional metadata to save along
|
||||
with model state..
|
||||
metadata (Dict[str, str]): Optional metadata to save along with model
|
||||
state.
|
||||
"""
|
||||
cache_data = [c.state for c in cache]
|
||||
cache_info = [c.meta_state if hasattr(c, "meta_state") else "" for c in cache]
|
||||
cache_info = [c.meta_state for c in cache]
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
cache_classes = [type(c).__name__ for c in cache]
|
||||
cache_metadata = [cache_classes, cache_info, metadata or ""]
|
||||
cache_metadata = [cache_classes, cache_info, metadata]
|
||||
cache_metadata = dict(tree_flatten(cache_metadata))
|
||||
mx.save_safetensors(file_name, cache_data, cache_metadata)
|
||||
|
||||
@ -73,14 +71,33 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
cache = [globals()[c]() for c in classes]
|
||||
for c, state, meta_state in zip(cache, arrays, info):
|
||||
c.state = state
|
||||
if hasattr(c, "meta_state"):
|
||||
c.meta_state = meta_state
|
||||
c.meta_state = meta_state
|
||||
if return_metadata:
|
||||
return cache, metadata
|
||||
return cache
|
||||
|
||||
|
||||
class KVCache:
|
||||
class _BaseCache:
|
||||
@property
|
||||
def state(self):
|
||||
return []
|
||||
|
||||
@state.setter
|
||||
def state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no state but a state was set.")
|
||||
|
||||
@property
|
||||
def meta_state(self):
|
||||
return {}
|
||||
|
||||
@state.setter
|
||||
def meta_state(self, v):
|
||||
if v is not None and v:
|
||||
raise ValueError("This cache has no meta_state but a meta_state was set.")
|
||||
|
||||
|
||||
class KVCache(_BaseCache):
|
||||
|
||||
def __init__(self):
|
||||
self.keys = None
|
||||
@ -128,7 +145,7 @@ class KVCache:
|
||||
self.offset = self.keys.shape[2]
|
||||
|
||||
|
||||
class RotatingKVCache:
|
||||
class RotatingKVCache(_BaseCache):
|
||||
|
||||
def __init__(self, max_size=None, keep=0, step=256):
|
||||
self.keep = keep
|
||||
@ -259,7 +276,7 @@ class RotatingKVCache:
|
||||
)
|
||||
|
||||
|
||||
class MambaCache:
|
||||
class MambaCache(_BaseCache):
|
||||
def __init__(self):
|
||||
self.cache = [None, None]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user