diff --git a/llms/mlx_lm/models/cache.py b/llms/mlx_lm/models/cache.py
index b06422e5..a6a56e0a 100644
--- a/llms/mlx_lm/models/cache.py
+++ b/llms/mlx_lm/models/cache.py
@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
     return cache
 
 
+def can_trim_prompt_cache(cache: List[Any]) -> bool:
+    """
+    Check if model's cache can be trimmed.
+    """
+    return all(c.is_trimmable() for c in cache)
+
+
 def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     """
     Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
     Returns:
         (int): The number of tokens that were trimmed.
     """
-    if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
+    if not can_trim_prompt_cache(cache) or len(cache) == 0:
         return 0
     return [c.trim(num_tokens) for c in cache][0]
 
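
A minimal usage sketch of the new guard composed with the existing trim function. The KVCache construction, the argument-free constructor, and the tensor shapes below are illustrative assumptions about this module, not part of this change:

    import mlx.core as mx
    from mlx_lm.models.cache import (
        KVCache,
        can_trim_prompt_cache,
        trim_prompt_cache,
    )

    # Toy two-layer cache filled with 10 dummy positions per layer (assumed
    # shapes follow the (batch, n_kv_heads, seq_len, head_dim) convention).
    cache = [KVCache() for _ in range(2)]
    for c in cache:
        c.update_and_fetch(mx.zeros((1, 4, 10, 8)), mx.zeros((1, 4, 10, 8)))

    # The new helper lets callers check up front whether every layer's cache
    # supports trimming, instead of calling trim_prompt_cache and testing
    # its return value for 0.
    if can_trim_prompt_cache(cache):
        trimmed = trim_prompt_cache(cache, 3)  # offset drops from 10 to 7
        print(f"trimmed {trimmed} tokens")

The refactor keeps trim_prompt_cache's behavior identical (it still returns 0 for an untrimmable or empty cache) while exposing the trimmability check as a public helper.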