From d1c35fa684ca070ad9e79d99885a2cd4261f8b8f Mon Sep 17 00:00:00 2001
From: Konstantin Kerekovski
Date: Fri, 3 May 2024 15:42:48 -0400
Subject: [PATCH] Add MLX Cache Limit setting for mlx_lm.generate and mlx_lm.server CLI (#744)

* Add support for setting MLX cache limit in GB

* Add support for setting MLX cache limit in GB in mlx_lm.server

* format

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/generate.py | 10 ++++++++++
 llms/mlx_lm/server.py   | 11 +++++++++++
 2 files changed, 21 insertions(+)

diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py
index da94eef2..477398b6 100644
--- a/llms/mlx_lm/generate.py
+++ b/llms/mlx_lm/generate.py
@@ -71,6 +71,13 @@ def setup_arg_parser():
         action="store_true",
         help="Colorize output based on T[0] probability",
     )
+    parser.add_argument(
+        "--cache-limit-gb",
+        type=int,
+        default=None,
+        help="Set the MLX cache limit in GB",
+        required=False,
+    )
     return parser
 
 
@@ -107,6 +114,9 @@ def main():
 
     mx.random.seed(args.seed)
 
+    if args.cache_limit_gb is not None:
+        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
+
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
     if args.eos_token is not None:
diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py
index 868f7a2f..0523be50 100644
--- a/llms/mlx_lm/server.py
+++ b/llms/mlx_lm/server.py
@@ -520,6 +520,13 @@ def main():
         choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
         help="Set the logging level (default: INFO)",
     )
+    parser.add_argument(
+        "--cache-limit-gb",
+        type=int,
+        default=None,
+        help="Set the MLX cache limit in GB",
+        required=False,
+    )
     args = parser.parse_args()
 
     logging.basicConfig(
@@ -527,6 +534,10 @@ def main():
         format="%(asctime)s - %(levelname)s - %(message)s",
     )
 
+    if args.cache_limit_gb is not None:
+        logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
+        mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
+
     # Building tokenizer_config
     tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}