mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
fix tests
This commit is contained in:
@@ -8,6 +8,18 @@ from mlx.utils import tree_flatten, tree_unflatten
|
||||
|
||||
|
||||
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
|
||||
"""
|
||||
Construct the model's cache for use when cgeneration.
|
||||
|
||||
This function will defer the cache construction to the model if it has a
|
||||
``make_cache`` method, otherwise it will make a default KV cache.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The language model.
|
||||
max_kv_size (Optional[int]): If provided and the model does not have a
|
||||
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
|
||||
size of ``max_kv_size``
|
||||
"""
|
||||
if hasattr(model, "make_cache"):
|
||||
return model.make_cache()
|
||||
|
||||
@@ -25,6 +37,12 @@ def save_prompt_cache(
|
||||
):
|
||||
"""
|
||||
Save a pre-computed prompt cache to a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
cache (List[Any]): The model state.
|
||||
metadata (Optional[Dict[str, str]]): Optional metadata to save along
|
||||
with model state..
|
||||
"""
|
||||
cache_data, cache_info = zip(*(c.state for c in cache))
|
||||
cache_data = dict(tree_flatten(cache_data))
|
||||
@@ -39,10 +57,11 @@ def save_prompt_cache(
|
||||
def load_prompt_cache(file_name, return_metadata=False):
|
||||
"""
|
||||
Load a prompt cache from a file.
|
||||
|
||||
Args:
|
||||
file_name (str): The ``.safetensors`` file name.
|
||||
return_metadata (bool): Whether or not to return metadata. Default:
|
||||
``False``.
|
||||
return_metadata (bool): Whether or not to return metadata.
|
||||
Default: ``False``.
|
||||
|
||||
Returns:
|
||||
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
|
||||
|
Reference in New Issue
Block a user