add can_trim_prompt_cache

This commit is contained in:
Awni Hannun
2024-10-09 13:02:02 -07:00
parent acec71b474
commit c9e9c75c66

View File

@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
    return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if the model's cache can be trimmed.

    Args:
        cache (List[Any]): The per-layer cache objects; each must expose
            an ``is_trimmable()`` method.

    Returns:
        bool: True when every layer's cache supports trimming (vacuously
        True for an empty cache), False otherwise.
    """
    for layer_cache in cache:
        if not layer_cache.is_trimmable():
            return False
    return True
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
    """
    Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
    Returns:
        (int): The number of tokens that were trimmed.
    """
    if not can_trim_prompt_cache(cache) or len(cache) == 0:
        return 0
    return [c.trim(num_tokens) for c in cache][0]