wired in MLX LM

This commit is contained in:
Awni Hannun 2024-10-23 19:59:28 -07:00
parent 9000e280ae
commit 131ccbe6df
2 changed files with 90 additions and 39 deletions

View File

@ -201,6 +201,31 @@ requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py) [example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details. for more usage details.
### Slow Speed with Large Models
> [!NOTE]
This requires macOS 15.0 or higher to work.
Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by keeping the memory
occupied by the model and cache wired. This requires macOS 15 or higher to
work.
If you see the following warning message:
> [WARNING] Generating with a model that requires ...
then the model will likely be very slow on the given machine. These cases can
be sped up for models which fit in RAM with some room to spare. To increase the
maximum wired limit, set the following `sysctl`:
```bash
sudo sysctl iogpu.wired_limit_mb=N
```
The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.
### Supported Models ### Supported Models
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to `mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc. # Copyright © 2023-2024 Apple Inc.
import contextlib
import copy import copy
import glob import glob
import importlib import importlib
@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx import mlx.core as mx
import mlx.nn as nn import mlx.nn as nn
from huggingface_hub import snapshot_download from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
# Local imports # Local imports
@ -39,6 +40,29 @@ class ModelNotFoundError(Exception):
super().__init__(self.message) super().__init__(self.message)
@contextlib.contextmanager
def wired_limit(model):
model_bytes = tree_reduce(
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
print(
"[WARNING] Generating with a model that requires {model_mb} MB "
"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be very slow. See the documentation for possible work-arounds: "
)
old_limit = mx.metal.set_wired_limit(max_rec_size)
try:
yield None
finally:
# TODO... expose a synchronize??
mx.zeros((1,)).item()
mx.metal.set_wired_limit(old_limit)
def _get_classes(config: dict): def _get_classes(config: dict):
""" """
Retrieve the model and model args classes based on the configuration. Retrieve the model and model args classes based on the configuration.
@ -330,48 +354,50 @@ def generate(
prompt_tokens = mx.array(tokenizer.encode(prompt)) prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer detokenizer = tokenizer.detokenizer
tic = time.perf_counter() with wired_limit(model):
detokenizer.reset() tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
):
if n == 0:
prompt_time = time.perf_counter() - tic
tic = time.perf_counter()
if token == tokenizer.eos_token_id:
break
detokenizer.add_token(token)
for n, (token, logprobs) in zip( if verbose:
range(max_tokens), if formatter:
generate_step(prompt_tokens, model, **kwargs), # We have to finalize so that the prob corresponds to the last segment
): detokenizer.finalize()
if n == 0: with mx.stream(mx.cpu):
prompt_time = time.perf_counter() - tic prob = mx.exp(logprobs[token]).item()
tic = time.perf_counter() formatter(detokenizer.last_segment, prob)
if token == tokenizer.eos_token_id: else:
break print(detokenizer.last_segment, end="", flush=True)
detokenizer.add_token(token)
token_count = n + 1
detokenizer.finalize()
if verbose: if verbose:
if formatter: gen_time = time.perf_counter() - tic
# We have to finalize so that the prob corresponds to the last segment print(detokenizer.last_segment, flush=True)
detokenizer.finalize() print("=" * 10)
with mx.stream(mx.cpu): if token_count == 0:
prob = mx.exp(logprobs[token]).item() print("No tokens generated for this prompt")
formatter(detokenizer.last_segment, prob) return
else: prompt_tps = prompt_tokens.size / prompt_time
print(detokenizer.last_segment, end="", flush=True) gen_tps = (token_count - 1) / gen_time
print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
token_count = n + 1 return detokenizer.text
detokenizer.finalize()
if verbose:
gen_time = time.perf_counter() - tic
print(detokenizer.last_segment, flush=True)
print("=" * 10)
if token_count == 0:
print("No tokens generated for this prompt")
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")
return detokenizer.text
def load_config(model_path: Path) -> dict: def load_config(model_path: Path) -> dict: