wired in MLX LM

This commit is contained in:
Awni Hannun 2024-10-23 19:59:28 -07:00
parent 9000e280ae
commit 131ccbe6df
2 changed files with 90 additions and 39 deletions

View File

@ -201,6 +201,31 @@ requests that use the same context. See the
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
for more usage details.
### Slow Speed with Large Models
> [!NOTE]
This requires macOS 15.0 or higher to work.
Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by keeping the memory
occupied by the model and cache wired. This requires macOS 15 or higher to
work.
If you see the following warning message:
> [WARNING] Generating with a model that requires ...
then the model will likely be very slow on the given machine. These cases can
be sped up for models which fit in RAM with some room to spare. To increase the
maximum wired limit, set the following `sysctl`:
```bash
sudo sysctl iogpu.wired_limit_mb=N
```
The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.
### Supported Models
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import contextlib
import copy
import glob
import importlib
@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
# Local imports
@ -39,6 +40,29 @@ class ModelNotFoundError(Exception):
super().__init__(self.message)
@contextlib.contextmanager
def wired_limit(model):
model_bytes = tree_reduce(
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
if model_bytes > 0.9 * max_rec_size:
model_mb = model_bytes // 2**20
max_rec_mb = max_rec_size // 2**20
print(
"[WARNING] Generating with a model that requires {model_mb} MB "
"which is close to the maximum recommended size of {max_rec_mb} "
"MB. This can be very slow. See the documentation for possible work-arounds: "
)
old_limit = mx.metal.set_wired_limit(max_rec_size)
try:
yield None
finally:
# TODO... expose a synchronize??
mx.zeros((1,)).item()
mx.metal.set_wired_limit(old_limit)
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
@ -330,9 +354,9 @@ def generate(
prompt_tokens = mx.array(tokenizer.encode(prompt))
detokenizer = tokenizer.detokenizer
with wired_limit(model):
tic = time.perf_counter()
detokenizer.reset()
for n, (token, logprobs) in zip(
range(max_tokens),
generate_step(prompt_tokens, model, **kwargs),
@ -366,7 +390,9 @@ def generate(
return
prompt_tps = prompt_tokens.size / prompt_time
gen_tps = (token_count - 1) / gen_time
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
print(
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
)
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
peak_mem = mx.metal.get_peak_memory() / 2**30
print(f"Peak memory: {peak_mem:.3f} GB")