diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index da94eef2..477398b6 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -71,6 +71,13 @@ def setup_arg_parser(): action="store_true", help="Colorize output based on T[0] probability", ) + parser.add_argument( + "--cache-limit-gb", + type=int, + default=None, + help="Set the MLX cache limit in GB", + required=False, + ) return parser @@ -107,6 +114,9 @@ def main(): mx.random.seed(args.seed) + if args.cache_limit_gb is not None: + mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) + # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None} if args.eos_token is not None: diff --git a/llms/mlx_lm/server.py b/llms/mlx_lm/server.py index 868f7a2f..0523be50 100644 --- a/llms/mlx_lm/server.py +++ b/llms/mlx_lm/server.py @@ -520,6 +520,13 @@ def main(): choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Set the logging level (default: INFO)", ) + parser.add_argument( + "--cache-limit-gb", + type=int, + default=None, + help="Set the MLX cache limit in GB", + required=False, + ) args = parser.parse_args() logging.basicConfig( @@ -527,6 +534,10 @@ def main(): format="%(asctime)s - %(levelname)s - %(message)s", ) + if args.cache_limit_gb is not None: + logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB") + mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024) + # Building tokenizer_config tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}