Wire models in MLX LM (#1069)

* wired in MLX LM

* fix synch

* comment + nit

* version

* mlx lm version

* bump to 0.19.2
Awni Hannun 2024-10-31 08:17:14 -07:00 committed by GitHub
parent 8fe9539af7
commit 9f34fdbda4
5 changed files with 104 additions and 42 deletions

View File

@ -248,3 +248,28 @@ model, tokenizer = load(
tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
)
```
### Large Models
> [!NOTE]
> This requires macOS 15.0 or higher to work.
Models which are large relative to the total RAM available on the machine can
be slow. `mlx-lm` will attempt to make them faster by wiring the memory
occupied by the model and cache. This requires macOS 15 or higher to
work.
If you see the following warning message:
> [WARNING] Generating with a model that requires ...
then the model will likely be slow on the given machine. If the model fits in
RAM then it can often be sped up by increasing the system wired memory limit.
To increase the limit, set the following `sysctl`:
```bash
sudo sysctl iogpu.wired_limit_mb=N
```
The value `N` should be larger than the size of the model in megabytes but
smaller than the memory size of the machine.
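
For example (illustrative numbers, not from the original text): a model that occupies roughly 40 GB on a machine with 64 GB of RAM could use a limit of 48 GB, which is above the model size and below total memory:

```bash
# Illustrative values only: ~40 GB model, 64 GB of RAM.
# 48 GB = 48 * 1024 = 49152 MB, above the model size and below total memory.
sudo sysctl iogpu.wired_limit_mb=49152
```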

View File

@ -1,3 +1,3 @@
# Copyright © 2023-2024 Apple Inc.
__version__ = "0.19.1"
__version__ = "0.19.3"

View File

@ -56,7 +56,7 @@ def main():
tokenizer_config={"trust_remote_code": True},
)
print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
prompt_cache = make_prompt_cache(model, args.max_kv_size)
while True:
query = input(">> ")

View File

@ -1,4 +1,4 @@
mlx>=0.17.0
mlx>=0.19.2
numpy
transformers[sentencepiece]>=4.39.3
protobuf

View File

@ -1,5 +1,6 @@
# Copyright © 2023-2024 Apple Inc.
import contextlib
import copy
import glob
import importlib
@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
import mlx.core as mx
import mlx.nn as nn
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx.utils import tree_flatten, tree_reduce
from transformers import PreTrainedTokenizer
# Local imports
@ -39,6 +40,40 @@ class ModelNotFoundError(Exception):
        super().__init__(self.message)


@contextlib.contextmanager
def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
    """
    A context manager to temporarily change the wired limit.

    Note, the wired limit should not be changed during an async eval. If an
    async eval could be running, pass in the streams to synchronize with prior
    to exiting the context manager.
    """
    model_bytes = tree_reduce(
        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
    )
    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
    if model_bytes > 0.9 * max_rec_size:
        model_mb = model_bytes // 2**20
        max_rec_mb = max_rec_size // 2**20
        print(
            f"[WARNING] Generating with a model that requires {model_mb} MB "
            f"which is close to the maximum recommended size of {max_rec_mb} "
            "MB. This can be slow. See the documentation for possible work-arounds: "
            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
        )
    old_limit = mx.metal.set_wired_limit(max_rec_size)
    try:
        yield None
    finally:
        if streams is not None:
            for s in streams:
                mx.synchronize(s)
        else:
            mx.synchronize()
        mx.metal.set_wired_limit(old_limit)
def _get_classes(config: dict):
"""
Retrieve the model and model args classes based on the configuration.
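
As a usage sketch (not part of the diff; the toy `nn.Linear` model is only a stand-in), the `wired_limit` context manager added above would typically wrap the compute whose memory should stay wired:

```python
import mlx.core as mx
import mlx.nn as nn

# Stand-in model; any nn.Module with parameters behaves the same way here.
model = nn.Linear(4, 4)

# Raise the wired limit for the duration of the block; on exit the default
# stream is synchronized and the previous limit is restored.
with wired_limit(model):
    y = model(mx.zeros((1, 4)))
    mx.eval(y)
```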
@ -330,48 +365,50 @@ def generate(
    prompt_tokens = mx.array(tokenizer.encode(prompt))
    detokenizer = tokenizer.detokenizer
    tic = time.perf_counter()
    detokenizer.reset()
    with wired_limit(model):
        tic = time.perf_counter()
        detokenizer.reset()
    for n, (token, logprobs) in zip(
        range(max_tokens),
        generate_step(prompt_tokens, model, **kwargs),
    ):
        if n == 0:
            prompt_time = time.perf_counter() - tic
            tic = time.perf_counter()
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)
        for n, (token, logprobs) in zip(
            range(max_tokens),
            generate_step(prompt_tokens, model, **kwargs),
        ):
            if n == 0:
                prompt_time = time.perf_counter() - tic
                tic = time.perf_counter()
            if token == tokenizer.eos_token_id:
                break
            detokenizer.add_token(token)
            if verbose:
                if formatter:
                    # We have to finalize so that the prob corresponds to the last segment
                    detokenizer.finalize()
                    with mx.stream(mx.cpu):
                        prob = mx.exp(logprobs[token]).item()
                    formatter(detokenizer.last_segment, prob)
                else:
                    print(detokenizer.last_segment, end="", flush=True)
        token_count = n + 1
        detokenizer.finalize()
        if verbose:
            if formatter:
                # We have to finalize so that the prob corresponds to the last segment
                detokenizer.finalize()
                with mx.stream(mx.cpu):
                    prob = mx.exp(logprobs[token]).item()
                formatter(detokenizer.last_segment, prob)
            else:
                print(detokenizer.last_segment, end="", flush=True)
            gen_time = time.perf_counter() - tic
            print(detokenizer.last_segment, flush=True)
            print("=" * 10)
            if token_count == 0:
                print("No tokens generated for this prompt")
                return
            prompt_tps = prompt_tokens.size / prompt_time
            gen_tps = (token_count - 1) / gen_time
            print(
                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
            )
            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
            peak_mem = mx.metal.get_peak_memory() / 2**30
            print(f"Peak memory: {peak_mem:.3f} GB")
    token_count = n + 1
    detokenizer.finalize()
    if verbose:
        gen_time = time.perf_counter() - tic
        print(detokenizer.last_segment, flush=True)
        print("=" * 10)
        if token_count == 0:
            print("No tokens generated for this prompt")
            return
        prompt_tps = prompt_tokens.size / prompt_time
        gen_tps = (token_count - 1) / gen_time
        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
        peak_mem = mx.metal.get_peak_memory() / 2**30
        print(f"Peak memory: {peak_mem:.3f} GB")
    return detokenizer.text
        return detokenizer.text
def load_config(model_path: Path) -> dict:
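
For context, a short sketch of the public API whose generation loop now runs under `wired_limit` (the model repo id is a placeholder, not taken from this commit):

```python
from mlx_lm import load, generate

# Placeholder repo id; any model converted for mlx-lm works the same way.
model, tokenizer = load("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
text = generate(model, tokenizer, prompt="Write a haiku about memory.",
                max_tokens=64, verbose=True)
```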