add can_trim_prompt_cache

This commit is contained in:
Awni Hannun 2024-10-09 13:02:02 -07:00
parent acec71b474
commit c9e9c75c66

View File

@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
"""
Check if model's cache can be trimmed.
"""
return all(c.is_trimmable() for c in cache)
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
Returns:
(int): The number of tokens that were trimmed.
"""
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]