From e510987870fdc0c9741d8448fe37a776d2ee52a0 Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Fri, 1 Nov 2024 14:15:32 -0700 Subject: [PATCH] Clear cache every now and then (#1081) * clear cache every now and then * don't need user arg anymore --- llms/mlx_lm/generate.py | 9 --------- llms/mlx_lm/utils.py | 4 ++++ 2 files changed, 4 insertions(+), 9 deletions(-) diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0355ca29..29976da2 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -90,12 +90,6 @@ def setup_arg_parser(): action="store_true", help="Colorize output based on T[0] probability", ) - parser.add_argument( - "--cache-limit-gb", - type=int, - default=None, - help="Set the MLX cache limit in GB", - ) parser.add_argument( "--max-kv-size", type=int, @@ -164,9 +158,6 @@ def main(): mx.random.seed(args.seed) - if args.cache_limit_gb is not None: - mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) - # Load the prompt cache and metadata if a cache file is provided using_cache = args.prompt_cache_file is not None if using_cache: diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 06784f10..b9fc202d 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -310,10 +310,14 @@ def generate_step( y, logprobs = _step(y) mx.async_eval(y, logprobs) + n = 0 while True: next_y, next_logprobs = _step(y) mx.async_eval(next_y, next_logprobs) yield y.item(), logprobs + if n % 256 == 0: + mx.metal.clear_cache() + n += 1 y, logprobs = next_y, next_logprobs