From 9f34fdbda4527e85ab6b98d9f343f7a2972085f1 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Thu, 31 Oct 2024 08:17:14 -0700
Subject: [PATCH] Wire models in MLX LM (#1069)

* wired in MLX LM

* fix synch

* comment + nit

* version

* mlx lm version

* bump to 0.19.2
---
 llms/README.md               |  25 ++++++++
 llms/mlx_lm/_version.py      |   2 +-
 llms/mlx_lm/chat.py          |   2 +-
 llms/mlx_lm/requirements.txt |   2 +-
 llms/mlx_lm/utils.py         | 115 +++++++++++++++++++++++------------
 5 files changed, 104 insertions(+), 42 deletions(-)

diff --git a/llms/README.md b/llms/README.md
index 20863041..f539988a 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -248,3 +248,28 @@ model, tokenizer = load(
     tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
 )
 ```
+
+### Large Models
+
+> [!NOTE]
+> This requires macOS 15.0 or higher to work.
+
+Models that are large relative to the total RAM available on the machine can
+be slow. `mlx-lm` will attempt to make them faster by wiring the memory
+occupied by the model and cache. This requires macOS 15 or higher to
+work.
+
+If you see the following warning message:
+
+> [WARNING] Generating with a model that requires ...
+
+then the model will likely be slow on the given machine. If the model fits in
+RAM, it can often be sped up by increasing the system wired memory limit.
+To increase the limit, set the following `sysctl`:
+
+```bash
+sudo sysctl iogpu.wired_limit_mb=N
+```
+
+The value `N` should be larger than the size of the model in megabytes but
+smaller than the memory size of the machine.
diff --git a/llms/mlx_lm/_version.py b/llms/mlx_lm/_version.py
index 70239db6..3811616f 100644
--- a/llms/mlx_lm/_version.py
+++ b/llms/mlx_lm/_version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.19.1"
+__version__ = "0.19.3"
diff --git a/llms/mlx_lm/chat.py b/llms/mlx_lm/chat.py
index 7968a868..ea1a99c7 100644
--- a/llms/mlx_lm/chat.py
+++ b/llms/mlx_lm/chat.py
@@ -56,7 +56,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
diff --git a/llms/mlx_lm/requirements.txt b/llms/mlx_lm/requirements.txt
index 814c03cc..48012863 100644
--- a/llms/mlx_lm/requirements.txt
+++ b/llms/mlx_lm/requirements.txt
@@ -1,4 +1,4 @@
-mlx>=0.17.0
+mlx>=0.19.2
 numpy
 transformers[sentencepiece]>=4.39.3
 protobuf
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 92741b68..5b437c98 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import contextlib
 import copy
 import glob
 import importlib
@@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -39,6 +40,40 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 
+@contextlib.contextmanager
+def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
+    """
+    A context manager to temporarily change the wired limit.
+
+    Note: the wired limit should not be changed during an async eval. If an
+    async eval could be running, pass in the streams to synchronize before
+    exiting the context manager.
+    """
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
+    if model_bytes > 0.9 * max_rec_size:
+        model_mb = model_bytes // 2**20
+        max_rec_mb = max_rec_size // 2**20
+        print(
+            f"[WARNING] Generating with a model that requires {model_mb} MB "
+            f"which is close to the maximum recommended size of {max_rec_mb} "
+            "MB. This can be slow. See the documentation for possible work-arounds: "
+            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
+        )
+    old_limit = mx.metal.set_wired_limit(max_rec_size)
+    try:
+        yield None
+    finally:
+        if streams is not None:
+            for s in streams:
+                mx.synchronize(s)
+        else:
+            mx.synchronize()
+        mx.metal.set_wired_limit(old_limit)
+
+
 def _get_classes(config: dict):
     """
     Retrieve the model and model args classes based on the configuration.
@@ -330,48 +365,50 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
-    tic = time.perf_counter()
-    detokenizer.reset()
+    with wired_limit(model):
+        tic = time.perf_counter()
+        detokenizer.reset()
+        for n, (token, logprobs) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                tic = time.perf_counter()
+            if token == tokenizer.eos_token_id:
+                break
+            detokenizer.add_token(token)
 
-    for n, (token, logprobs) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if n == 0:
-            prompt_time = time.perf_counter() - tic
-            tic = time.perf_counter()
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
+            if verbose:
+                if formatter:
+                    # We have to finalize so that the prob corresponds to the last segment
+                    detokenizer.finalize()
+                    with mx.stream(mx.cpu):
+                        prob = mx.exp(logprobs[token]).item()
+                    formatter(detokenizer.last_segment, prob)
+                else:
+                    print(detokenizer.last_segment, end="", flush=True)
+
+        token_count = n + 1
+        detokenizer.finalize()
 
         if verbose:
-            if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
-                detokenizer.finalize()
-                with mx.stream(mx.cpu):
-                    prob = mx.exp(logprobs[token]).item()
-                formatter(detokenizer.last_segment, prob)
-            else:
-                print(detokenizer.last_segment, end="", flush=True)
+            gen_time = time.perf_counter() - tic
+            print(detokenizer.last_segment, flush=True)
+            print("=" * 10)
+            if token_count == 0:
+                print("No tokens generated for this prompt")
+                return
+            prompt_tps = prompt_tokens.size / prompt_time
+            gen_tps = (token_count - 1) / gen_time
+            print(
+                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
+            )
+            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+            peak_mem = mx.metal.get_peak_memory() / 2**30
+            print(f"Peak memory: {peak_mem:.3f} GB")
 
-    token_count = n + 1
-    detokenizer.finalize()
-
-    if verbose:
-        gen_time = time.perf_counter() - tic
-        print(detokenizer.last_segment, flush=True)
-        print("=" * 10)
-        if token_count == 0:
-            print("No tokens generated for this prompt")
-            return
-        prompt_tps = prompt_tokens.size / prompt_time
-        gen_tps = (token_count - 1) / gen_time
-        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-        peak_mem = mx.metal.get_peak_memory() / 2**30
-        print(f"Peak memory: {peak_mem:.3f} GB")
-
-    return detokenizer.text
+        return detokenizer.text
 
 
 def load_config(model_path: Path) -> dict:
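
A minimal usage sketch of the path this patch changes, assuming `mlx-lm` >= 0.19.3: `generate` now enters `wired_limit(model)` internally, so callers get the wired-memory behavior (and the `[WARNING]` described in the README section above) without any extra calls. The model name below is only an example checkpoint; substitute any MLX-compatible model.

```python
# Sketch assuming mlx-lm >= 0.19.3 (the version introduced by this patch).
# The model repo is an example; any MLX-compatible checkpoint works.
from mlx_lm import load, generate

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# generate() wraps the decode loop in wired_limit(model); if the model is
# close to the recommended working-set size, the warning is printed here and
# the iogpu.wired_limit_mb sysctl from the README may help.
text = generate(
    model,
    tokenizer,
    prompt="Explain wired memory in one sentence.",
    max_tokens=64,
    verbose=True,
)
```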