diff --git a/llms/mlx_lm/generate.py b/llms/mlx_lm/generate.py index 0bf98ab2..6689d58c 100644 --- a/llms/mlx_lm/generate.py +++ b/llms/mlx_lm/generate.py @@ -107,6 +107,14 @@ def setup_arg_parser(): default=None, help="A file containing saved KV caches to avoid recomputing them", ) + parser.add_argument( + "--wire-model", + "-w", + action="store_true", + help=("Keep the model resident in memory. This can substantially " + "speedup generation for models large relative to the machine's RAM.") + ) + return parser @@ -216,6 +224,14 @@ def main(): raise ValueError("Cannot use --colorize with --verbose=False") formatter = colorprint_by_t0 if args.colorize else None + if args.wire_model: + wired_bytes = mx.metal.wire(model) + if wired_bytes >= mx.metal.device_info()["max_recommended_working_set_size"]: + raise ValueError( + "Cannot wire a model larger than the available RAM. You may " + "be able to increase the available RAM by setting " + "`sudo sysctl iogpu.wired_limit_mb=N` to a larger value") + response = generate( model, tokenizer,