From 131ccbe6df11ae012f6eee84ad6790499ec5400e Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Wed, 23 Oct 2024 19:59:28 -0700
Subject: [PATCH] wired in MLX LM

---
 llms/README.md       |  25 +++++++++++
 llms/mlx_lm/utils.py | 104 +++++++++++++++++++++++++++----------------
 2 files changed, 90 insertions(+), 39 deletions(-)

diff --git a/llms/README.md b/llms/README.md
index 20863041..fd625879 100644
--- a/llms/README.md
+++ b/llms/README.md
@@ -201,6 +201,31 @@ requests that use the same context. See the
 [example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
 for more usage details.
 
+### Slow Speed with Large Models
+
+> [!NOTE]
+> This requires macOS 15.0 or higher to work.
+
+Models which are large relative to the total RAM available on the machine can
+be slow. `mlx-lm` will attempt to make them faster by keeping the memory
+occupied by the model and cache wired. This requires macOS 15 or higher to
+work.
+
+If you see the following warning message:
+
+> [WARNING] Generating with a model that requires ...
+
+then the model will likely be very slow on the given machine. These cases can
+be sped up for models which fit in RAM with some room to spare. To increase the
+maximum wired limit, set the following `sysctl`:
+
+```bash
+sudo sysctl iogpu.wired_limit_mb=N
+```
+
+The value `N` should be larger than the size of the model in megabytes but
+smaller than the memory size of the machine.
+
 ### Supported Models
 
 `mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 92741b68..e9388dbe 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import contextlib
 import copy
 import glob
 import importlib
@@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -39,6 +40,29 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 
+@contextlib.contextmanager
+def wired_limit(model):
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
+    if model_bytes > 0.9 * max_rec_size:
+        model_mb = model_bytes // 2**20
+        max_rec_mb = max_rec_size // 2**20
+        print(
+            f"[WARNING] Generating with a model that requires {model_mb} MB "
+            f"which is close to the maximum recommended size of {max_rec_mb} "
+            "MB. This can be very slow. See the documentation for possible work-arounds: "
+        )
+    old_limit = mx.metal.set_wired_limit(max_rec_size)
+    try:
+        yield None
+    finally:
+        # TODO... expose a synchronize??
+        mx.zeros((1,)).item()
+        mx.metal.set_wired_limit(old_limit)
+
+
 def _get_classes(config: dict):
     """
     Retrieve the model and model args classes based on the configuration.
@@ -330,48 +354,50 @@ def generate( prompt_tokens = mx.array(tokenizer.encode(prompt)) detokenizer = tokenizer.detokenizer - tic = time.perf_counter() - detokenizer.reset() + with wired_limit(model): + tic = time.perf_counter() + detokenizer.reset() + for n, (token, logprobs) in zip( + range(max_tokens), + generate_step(prompt_tokens, model, **kwargs), + ): + if n == 0: + prompt_time = time.perf_counter() - tic + tic = time.perf_counter() + if token == tokenizer.eos_token_id: + break + detokenizer.add_token(token) - for n, (token, logprobs) in zip( - range(max_tokens), - generate_step(prompt_tokens, model, **kwargs), - ): - if n == 0: - prompt_time = time.perf_counter() - tic - tic = time.perf_counter() - if token == tokenizer.eos_token_id: - break - detokenizer.add_token(token) + if verbose: + if formatter: + # We have to finalize so that the prob corresponds to the last segment + detokenizer.finalize() + with mx.stream(mx.cpu): + prob = mx.exp(logprobs[token]).item() + formatter(detokenizer.last_segment, prob) + else: + print(detokenizer.last_segment, end="", flush=True) + + token_count = n + 1 + detokenizer.finalize() if verbose: - if formatter: - # We have to finalize so that the prob corresponds to the last segment - detokenizer.finalize() - with mx.stream(mx.cpu): - prob = mx.exp(logprobs[token]).item() - formatter(detokenizer.last_segment, prob) - else: - print(detokenizer.last_segment, end="", flush=True) + gen_time = time.perf_counter() - tic + print(detokenizer.last_segment, flush=True) + print("=" * 10) + if token_count == 0: + print("No tokens generated for this prompt") + return + prompt_tps = prompt_tokens.size / prompt_time + gen_tps = (token_count - 1) / gen_time + print( + f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec" + ) + print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") + peak_mem = mx.metal.get_peak_memory() / 2**30 + print(f"Peak memory: {peak_mem:.3f} GB") - token_count = n + 1 - 
detokenizer.finalize() - - if verbose: - gen_time = time.perf_counter() - tic - print(detokenizer.last_segment, flush=True) - print("=" * 10) - if token_count == 0: - print("No tokens generated for this prompt") - return - prompt_tps = prompt_tokens.size / prompt_time - gen_tps = (token_count - 1) / gen_time - print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec") - print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec") - peak_mem = mx.metal.get_peak_memory() / 2**30 - print(f"Peak memory: {peak_mem:.3f} GB") - - return detokenizer.text + return detokenizer.text def load_config(model_path: Path) -> dict: