Add MLX Cache Limit setting for mlx_lm.generate and mlx_lm.server CLI (#744)

* Add support for setting MLX cache limit in GB

* Add support for setting MLX cache limit in GB in mlx_lm.server

* format

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Konstantin Kerekovski 2024-05-03 15:42:48 -04:00 committed by GitHub
parent b468091f7f
commit d1c35fa684
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 21 additions and 0 deletions

View File

@ -71,6 +71,13 @@ def setup_arg_parser():
action="store_true",
help="Colorize output based on T[0] probability",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
required=False,
)
return parser
@ -107,6 +114,9 @@ def main():
mx.random.seed(args.seed)
if args.cache_limit_gb is not None:
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}
if args.eos_token is not None:

View File

@ -520,6 +520,13 @@ def main():
choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
help="Set the logging level (default: INFO)",
)
parser.add_argument(
"--cache-limit-gb",
type=int,
default=None,
help="Set the MLX cache limit in GB",
required=False,
)
args = parser.parse_args()
logging.basicConfig(
@ -527,6 +534,10 @@ def main():
format="%(asctime)s - %(levelname)s - %(message)s",
)
if args.cache_limit_gb is not None:
logging.debug(f"Setting cache limit to {args.cache_limit_gb} GB")
mx.metal.set_cache_limit(args.cache_limit_gb * 1024 * 1024 * 1024)
# Building tokenizer_config
tokenizer_config = {"trust_remote_code": True if args.trust_remote_code else None}