mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
wired in MLX LM
This commit is contained in:
parent
9000e280ae
commit
131ccbe6df
@ -201,6 +201,31 @@ requests that use the same context. See the
|
||||
[example](https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/examples/chat.py)
|
||||
for more usage details.
|
||||
|
||||
### Slow Speed with Large Models
|
||||
|
||||
> [!NOTE]
|
||||
This requires macOS 15.0 or higher to work.
|
||||
|
||||
Models which are large relative to the total RAM available on the machine can
|
||||
be slow. `mlx-lm` will attempt to make them faster by keeping the memory
|
||||
occupied by the model and cache wired. This requires macOS 15 or higher to
|
||||
work.
|
||||
|
||||
If you see the following warning message:
|
||||
|
||||
> [WARNING] Generating with a model that requires ...
|
||||
|
||||
then the model will likely be very slow on the given machine. These cases can
|
||||
be sped up for models which fit in RAM with some room to spare. To increase the
|
||||
maximum wired limit, set the following `sysctl`:
|
||||
|
||||
```bash
|
||||
sudo sysctl iogpu.wired_limit_mb=N
|
||||
```
|
||||
|
||||
The value `N` should be larger than the size of the model in megabytes but
|
||||
smaller than the memory size of the machine.
|
||||
|
||||
### Supported Models
|
||||
|
||||
`mlx-lm` supports thousands of Hugging Face format LLMs. If the model you want to
|
||||
|
@ -1,5 +1,6 @@
|
||||
# Copyright © 2023-2024 Apple Inc.
|
||||
|
||||
import contextlib
|
||||
import copy
|
||||
import glob
|
||||
import importlib
|
||||
@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from huggingface_hub import snapshot_download
|
||||
from mlx.utils import tree_flatten
|
||||
from mlx.utils import tree_flatten, tree_reduce
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
# Local imports
|
||||
@ -39,6 +40,29 @@ class ModelNotFoundError(Exception):
|
||||
super().__init__(self.message)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def wired_limit(model):
|
||||
model_bytes = tree_reduce(
|
||||
lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
|
||||
)
|
||||
max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
|
||||
if model_bytes > 0.9 * max_rec_size:
|
||||
model_mb = model_bytes // 2**20
|
||||
max_rec_mb = max_rec_size // 2**20
|
||||
print(
|
||||
"[WARNING] Generating with a model that requires {model_mb} MB "
|
||||
"which is close to the maximum recommended size of {max_rec_mb} "
|
||||
"MB. This can be very slow. See the documentation for possible work-arounds: "
|
||||
)
|
||||
old_limit = mx.metal.set_wired_limit(max_rec_size)
|
||||
try:
|
||||
yield None
|
||||
finally:
|
||||
# TODO... expose a synchronize??
|
||||
mx.zeros((1,)).item()
|
||||
mx.metal.set_wired_limit(old_limit)
|
||||
|
||||
|
||||
def _get_classes(config: dict):
|
||||
"""
|
||||
Retrieve the model and model args classes based on the configuration.
|
||||
@ -330,48 +354,50 @@ def generate(
|
||||
prompt_tokens = mx.array(tokenizer.encode(prompt))
|
||||
detokenizer = tokenizer.detokenizer
|
||||
|
||||
tic = time.perf_counter()
|
||||
detokenizer.reset()
|
||||
with wired_limit(model):
|
||||
tic = time.perf_counter()
|
||||
detokenizer.reset()
|
||||
for n, (token, logprobs) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
tic = time.perf_counter()
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
detokenizer.add_token(token)
|
||||
|
||||
for n, (token, logprobs) in zip(
|
||||
range(max_tokens),
|
||||
generate_step(prompt_tokens, model, **kwargs),
|
||||
):
|
||||
if n == 0:
|
||||
prompt_time = time.perf_counter() - tic
|
||||
tic = time.perf_counter()
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
detokenizer.add_token(token)
|
||||
if verbose:
|
||||
if formatter:
|
||||
# We have to finalize so that the prob corresponds to the last segment
|
||||
detokenizer.finalize()
|
||||
with mx.stream(mx.cpu):
|
||||
prob = mx.exp(logprobs[token]).item()
|
||||
formatter(detokenizer.last_segment, prob)
|
||||
else:
|
||||
print(detokenizer.last_segment, end="", flush=True)
|
||||
|
||||
token_count = n + 1
|
||||
detokenizer.finalize()
|
||||
|
||||
if verbose:
|
||||
if formatter:
|
||||
# We have to finalize so that the prob corresponds to the last segment
|
||||
detokenizer.finalize()
|
||||
with mx.stream(mx.cpu):
|
||||
prob = mx.exp(logprobs[token]).item()
|
||||
formatter(detokenizer.last_segment, prob)
|
||||
else:
|
||||
print(detokenizer.last_segment, end="", flush=True)
|
||||
gen_time = time.perf_counter() - tic
|
||||
print(detokenizer.last_segment, flush=True)
|
||||
print("=" * 10)
|
||||
if token_count == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt_tokens.size / prompt_time
|
||||
gen_tps = (token_count - 1) / gen_time
|
||||
print(
|
||||
f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
|
||||
)
|
||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||
|
||||
token_count = n + 1
|
||||
detokenizer.finalize()
|
||||
|
||||
if verbose:
|
||||
gen_time = time.perf_counter() - tic
|
||||
print(detokenizer.last_segment, flush=True)
|
||||
print("=" * 10)
|
||||
if token_count == 0:
|
||||
print("No tokens generated for this prompt")
|
||||
return
|
||||
prompt_tps = prompt_tokens.size / prompt_time
|
||||
gen_tps = (token_count - 1) / gen_time
|
||||
print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
|
||||
print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
|
||||
peak_mem = mx.metal.get_peak_memory() / 2**30
|
||||
print(f"Peak memory: {peak_mem:.3f} GB")
|
||||
|
||||
return detokenizer.text
|
||||
return detokenizer.text
|
||||
|
||||
|
||||
def load_config(model_path: Path) -> dict:
|
||||
|
Loading…
Reference in New Issue
Block a user