diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index 7d703ee3..009ad1dd 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -8,6 +8,18 @@ from mlx.utils import tree_flatten, tree_unflatten
 
 
 def make_prompt_cache(model: nn.Module, max_kv_size: Optional[int] = None) -> List[Any]:
+    """
+    Construct the model's cache for use when generating.
+
+    This function will defer the cache construction to the model if it has a
+    ``make_cache`` method, otherwise it will make a default KV cache.
+
+    Args:
+        model (nn.Module): The language model.
+        max_kv_size (Optional[int]): If provided and the model does not have a
+            ``make_cache`` method, a ``RotatingKVCache`` is used with a maximum
+            size of ``max_kv_size``.
+    """
     if hasattr(model, "make_cache"):
         return model.make_cache()
 
@@ -25,6 +37,12 @@ def save_prompt_cache(
 ):
     """
     Save a pre-computed prompt cache to a file.
+
+    Args:
+        file_name (str): The ``.safetensors`` file name.
+        cache (List[Any]): The model state.
+        metadata (Optional[Dict[str, str]]): Optional metadata to save along
+            with the model state.
     """
     cache_data, cache_info = zip(*(c.state for c in cache))
     cache_data = dict(tree_flatten(cache_data))
@@ -39,10 +57,11 @@ def save_prompt_cache(
 def load_prompt_cache(file_name, return_metadata=False):
     """
     Load a prompt cache from a file.
+
     Args:
         file_name (str): The ``.safetensors`` file name.
-        return_metadata (bool): Whether or not to return metadata. Default:
-            ``False``.
+        return_metadata (bool): Whether or not to return metadata.
+            Default: ``False``.
 
     Returns:
         List[Any] or Tuple[List[Any], Dict[str, str]]: The prompt cache and
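
For context, a minimal usage sketch of the functions documented above. Only make_prompt_cache, save_prompt_cache, and load_prompt_cache come from this diff; the mlx_lm.load call, the model identifier, the file name, and the metadata values are illustrative assumptions, and the prefill/generation step that populates the cache is elided.

    # Hypothetical sketch; names outside mlx_lm.models.cache are assumptions.
    from mlx_lm import load
    from mlx_lm.models.cache import (
        load_prompt_cache,
        make_prompt_cache,
        save_prompt_cache,
    )

    # Assumed model id, for illustration only.
    model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")

    # Build a cache: defers to model.make_cache() when available, otherwise a
    # default KV cache (or a RotatingKVCache when max_kv_size is given).
    cache = make_prompt_cache(model, max_kv_size=1024)

    # ... run prefill/generation here so `cache` holds the processed prompt ...

    # Persist the cache to a .safetensors file with optional string metadata.
    save_prompt_cache("prompt.safetensors", cache, {"note": "system prompt v1"})

    # Restore it later, requesting the saved metadata as well.
    cache, metadata = load_prompt_cache("prompt.safetensors", return_metadata=True)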