mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 12:49:50 +08:00
add can_trim_prompt_cache
This commit is contained in:
@@ -77,6 +77,13 @@ def load_prompt_cache(file_name, return_metadata=False):
|
|||||||
return cache
|
return cache
|
||||||
|
|
||||||
|
|
||||||
|
def can_trim_prompt_cache(cache: List[Any]) -> bool:
    """
    Check if model's cache can be trimmed.

    Args:
        cache (List[Any]): The model's cache, given as a list of cache
          objects. Each object must provide an ``is_trimmable()`` method.

    Returns:
        (bool): ``True`` if every object in ``cache`` reports that it is
          trimmable, ``False`` as soon as one does not. An empty list
          yields ``True`` because there is nothing that blocks trimming.
    """
    # ``all`` short-circuits on the first entry that is not trimmable.
    return all(c.is_trimmable() for c in cache)
|
||||||
|
|
||||||
|
|
||||||
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
||||||
"""
|
"""
|
||||||
Trim the model's cache by the given number of tokens.
|
Trim the model's cache by the given number of tokens.
|
||||||
@@ -91,7 +98,7 @@ def trim_prompt_cache(cache: List[Any], num_tokens: int) -> List[Any]:
|
|||||||
Returns:
|
Returns:
|
||||||
(int): The number of tokens that were trimmed.
|
(int): The number of tokens that were trimmed.
|
||||||
"""
|
"""
|
||||||
if not all(c.is_trimmable() for c in cache) or len(cache) == 0:
|
if not can_trim_prompt_cache(cache) or len(cache) == 0:
|
||||||
return 0
|
return 0
|
||||||
return [c.trim(num_tokens) for c in cache][0]
|
return [c.trim(num_tokens) for c in cache][0]
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user