fix tests

Awni Hannun
2024-10-05 15:32:07 -07:00
parent 62dbd418d9
commit 4dc3cc0300


@@ -8,6 +8,18 @@ from mlx.utils import tree_flatten, tree_unflatten
def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
"""
Construct the model's cache for use when generating.
This function will defer the cache construction to the model if it has a
``make_cache`` method, otherwise it will make a default KV cache.
Args:
model (nn.Module): The language model.
max_kv_size (Optional[int]): If provided and the model does not have a
``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
size of ``max_kv_size``.
"""
if hasattr(model, "make_cache"):
return model.make_cache()
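For context, a minimal usage sketch of this helper (assuming the package layout at this commit, with the cache utilities importable from mlx_lm.models.cache; the model name is just an example):

    from mlx_lm import load
    from mlx_lm.models.cache import make_prompt_cache

    model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")  # example model

    # Defers to model.make_cache() when the model defines it, otherwise
    # builds a default KV cache for each layer.
    cache = make_prompt_cache(model)

    # With max_kv_size set (and no make_cache method on the model), a
    # RotatingKVCache capped at 512 entries is used instead.
    bounded_cache = make_prompt_cache(model, max_kv_size=512)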
@@ -25,6 +37,12 @@ def save_prompt_cache(
):
"""
Save a pre-computed prompt cache to a file.
Args:
file_name (str): The ``.safetensors`` file name.
cache (List[Any]): The model state.
metadata (Optional[Dict[str, str]]): Optional metadata to save along
with the model state.
"""
cache_data, cache_info = zip(*(c.state for c in cache))
cache_data = dict(tree_flatten(cache_data))
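A short sketch of saving a cache after prompt processing (the file name and metadata values are illustrative, not from this commit):

    from mlx_lm.models.cache import save_prompt_cache

    # `cache` holds the per-layer state produced while processing a prompt,
    # e.g. the list returned by make_prompt_cache after evaluation.
    save_prompt_cache(
        "prompt_cache.safetensors",
        cache,
        metadata={"model": "my-model", "prompt_tokens": "128"},  # optional
    )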
@@ -39,10 +57,11 @@ def save_prompt_cache(
def load_prompt_cache(file_name, return_metadata=False):
"""
Load a prompt cache from a file.
Args:
file_name (str): The ``.safetensors`` file name.
return_metadata (bool): Whether or not to return metadata.
Default: ``False``.
Returns:
List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
the metadata if requested.
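Round-tripping the file back (a sketch; the file name matches the save example above):

    from mlx_lm.models.cache import load_prompt_cache

    # Cache only (the default behavior).
    cache = load_prompt_cache("prompt_cache.safetensors")

    # Cache plus the metadata dict that was saved alongside it.
    cache, metadata = load_prompt_cache(
        "prompt_cache.safetensors", return_metadata=True
    )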