mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
wired in MLX LM
This commit is contained in:
parent
9000e280ae
commit
131ccbe6df
@ -201,6 +201,31 @@ requests that use the same context. See the
|
|||||||
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
|
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
|
||||||
for more usage details.
|
for more usage details.
|
||||||
|
|
||||||
|
### Slow Speed with Large Models
|
||||||
|
|
||||||
|
> [!NOTE]
|
||||||
|
This requires macOS 15.0 or higher to work.
|
||||||
|
|
||||||
|
Models which are large relative to the total RAM available on the machine can
|
||||||
|
be slow. `mlx-lm` will attempt to make them faster by keeping the memory
|
||||||
|
occupied by the model and cache wired. This requires macOS 15 or higher to
|
||||||
|
work.
|
||||||
|
|
||||||
|
If you see the following warning message:
|
||||||
|
|
||||||
|
> [WARNING] Generating with a model that requires ...
|
||||||
|
|
||||||
|
then the model will likely be very slow on the given machine. These cases can
|
||||||
|
be sped up for models which fit in RAM with some room to spare. To increase the
|
||||||
|
maximum wired limit, set the following `sysctl`:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
sudo sysctl iogpu.wired_limit_mb=N
|
||||||
|
```
|
||||||
|
|
||||||
|
The value `N` should be larger than the size of the model in megabytes but
|
||||||
|
smaller than the memory size of the machine.
|
||||||
|
|
||||||
### Supported Models
|
### Supported Models
|
||||||
|
|
||||||
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
|
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
|
||||||
|
@ -1,5 +1,6 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
|
import contextlib
|
||||||
import copy
|
import copy
|
||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
|||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
import mlx.nn as nn
|
import mlx.nn as nn
|
||||||
from huggingface_hub import snapshot_download
|
from huggingface_hub import snapshot_download
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten, tree_reduce
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
# Local imports
|
# Local imports
|
||||||
@ -39,6 +40,29 @@ class ModelNotFoundError(Exception):
|
|||||||
super().__init__(self.message)
|
super().__init__(self.message)
|
||||||
|
|
||||||
|
|
||||||
|
@contextlib.contextmanager
def wired_limit(model):
    """
    Temporarily raise the Metal wired memory limit around a block of work.

    On entry the wired limit is set to the device's
    ``max_recommended_working_set_size`` (as reported by
    ``mx.metal.device_info()``). On exit the value returned by that
    ``set_wired_limit`` call is passed back to ``set_wired_limit`` to restore
    the prior setting. A warning is printed when the summed size of the
    model's arrays exceeds 90% of the recommended working set size.

    Args:
        model: A tree whose ``mx.array`` leaves are summed (via ``nbytes``)
            to estimate the model size. Non-array leaves are ignored.

    Yields:
        None
    """
    # Total bytes held by every mx.array leaf in the model tree.
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        # The f-prefix is required on the first two literals; without it the
        # placeholders are printed verbatim instead of the computed sizes.
        # NOTE(review): the message ends with "work-arounds: " and no link
        # follows -- presumably a documentation URL is meant to be appended.
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be very slow. See the documentation for possible work-arounds: "
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        # TODO... expose a synchronize??
        # Evaluating a trivial array presumably stands in for a device
        # synchronization before the limit is changed back -- confirm.
        mx.zeros((1,)).item()
        mx.metal.set_wired_limit(old_limit)
|
||||||
|
|
||||||
|
|
||||||
def _get_classes(config: dict):
|
def _get_classes(config: dict):
|
||||||
"""
|
"""
|
||||||
Retrieve the model and model args classes based on the configuration.
|
Retrieve the model and model args classes based on the configuration.
|
||||||
@ -330,9 +354,9 @@ def generate(
|
|||||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||||
detokenizer = tokenizer.detokenizer
|
detokenizer = tokenizer.detokenizer
|
||||||
|
|
||||||
|
with wired_limit(model):
|
||||||
tic = time.perf_counter()
|
tic = time.perf_counter()
|
||||||
detokenizer.reset()
|
detokenizer.reset()
|
||||||
|
|
||||||
for n, (token, logprobs) in zip(
|
for n, (token, logprobs) in zip(
|
||||||
range(max_tokens),
|
range(max_tokens),
|
||||||
generate_step(prompt_tokens, model, **kwargs),
|
generate_step(prompt_tokens, model, **kwargs),
|
||||||
@ -366,7 +390,9 @@ def generate(
|
|||||||
return
|
return
|
||||||
prompt_tps = prompt_tokens.size / prompt_time
|
prompt_tps = prompt_tokens.size / prompt_time
|
||||||
gen_tps = (token_count - 1) / gen_time
|
gen_tps = (token_count - 1) / gen_time
|
||||||
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
print(
|
||||||
|
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
||||||
|
)
|
||||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||||
|
Loading…
Reference in New Issue
Block a user