Tokenizer updates + tests (#1024)

* tokenizer updates + tests

* nit

* add can_trim_prompt_cache

* nits
This commit is contained in:
Awni Hannun
2024-10-14 10:48:46 -07:00
committed by GitHub
parent 6c368f2124
commit 8dca1a2f60
4 changed files with 108 additions and 23 deletions

View File

@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
return cache
def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Report whether every entry of the model's cache supports trimming.

    Args:
        cache (List[Any]): The model's cache entries. Each entry is queried
          through its ``is_trimmable()`` method.

    Returns:
        (bool): ``False`` as soon as one entry reports that it is not
        trimmable, ``True`` otherwise (an empty cache yields ``True``).
    """
    for entry in cache:
        # Stop at the first entry that cannot be trimmed; later entries
        # are not queried.
        if not entry.is_trimmable():
            return False
    return True
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
"""
Trim the model's cache by the given number of tokens.
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
Returns:
(int): The number of tokens that were trimmed.
"""
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
if not can_trim_prompt_cache(cache) or len(cache) == 0:
return 0
return [c.trim(num_tokens) for c in cache][0]