mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
add can_trim_prompt_cache
This commit is contained in:
parent
acec71b474
commit
c9e9c75c66
@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
|
||||
return cache
|
||||
|
||||
|
||||
def can_trim_prompt_cache(cache: List[Any]) -> bool:
|
||||
"""
|
||||
Check if model's cache can be trimmed.
|
||||
"""
|
||||
return all(c.is_trimmable() for c in cache)
|
||||
|
||||
|
||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
"""
|
||||
Trim the model's cache by the given number of tokens.
|
||||
@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||
Returns:
|
||||
(int): The number of tokens that were trimmed.
|
||||
"""
|
||||
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
|
||||
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||
return 0
|
||||
return [c.trim(num_tokens) for c in cache][0]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user