Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-06-24 09:21:18 +08:00
Wire models in MLX LM (#1069)
* wired in MLX LM
* fix synch
* comment + nit
* version
* mlx lm version
* bump to 0.19.2
This commit is contained in:
parent 8fe9539af7
commit 9f34fdbda4
@@ -248,3 +248,27 @@ model, tokenizer = load(
     tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
 )
 ```
+
+### Large Models
+
+> [!NOTE]
+> This requires macOS 15.0 or higher to work.
+
+Models which are large relative to the total RAM available on the machine can
+be slow. `mlx-lm` will attempt to make them faster by wiring the memory
+occupied by the model and cache.
+
+If you see the following warning message:
+
+> [WARNING] Generating with a model that requires ...
+
+then the model will likely be slow on the given machine. If the model fits in
+RAM, it can often be sped up by raising the system wired memory limit. To
+increase the limit, set the following `sysctl`:
+
+```bash
+sudo sysctl iogpu.wired_limit_mb=N
+```
+
+The value `N` should be larger than the size of the model in megabytes but
+smaller than the memory size of the machine.
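For example, on a machine with 64 GB of RAM and a model that occupies roughly 40 GB, any `N` between those two sizes works; the figures below are illustrative, not from this commit:

```bash
# Illustrative values: wire up to 56 GB (57344 MB) on a 64 GB machine,
# leaving headroom for the OS. The setting does not persist across reboots.
sudo sysctl iogpu.wired_limit_mb=57344
```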
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.19.1"
+__version__ = "0.19.3"
@@ -56,7 +56,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
        query = input(">> ")
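The fixed banner belongs to the `mlx_lm.chat` REPL. A typical invocation looks like the following; the flag name is inferred from `args.model` in the surrounding code and the model name is illustrative:

```bash
# Start an interactive chat session; enter 'q' to quit.
python -m mlx_lm.chat --model mlx-community/Mistral-7B-Instruct-v0.3-4bit
```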
@@ -1,4 +1,4 @@
-mlx>=0.17.0
+mlx>=0.19.2
 numpy
 transformers[sentencepiece]>=4.39.3
 protobuf
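Existing installs can pick up the new minimum with a standard upgrade (not part of this commit):

```bash
pip install --upgrade mlx mlx-lm
```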
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import contextlib
 import copy
 import glob
 import importlib
@@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -39,6 +40,40 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 
+@contextlib.contextmanager
+def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
+    """
+    A context manager to temporarily change the wired limit.
+
+    Note: the wired limit should not be changed during an async eval. If an
+    async eval could be running, pass in the streams to synchronize with,
+    prior to exiting the context manager.
+    """
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
+    if model_bytes > 0.9 * max_rec_size:
+        model_mb = model_bytes // 2**20
+        max_rec_mb = max_rec_size // 2**20
+        print(
+            f"[WARNING] Generating with a model that requires {model_mb} MB "
+            f"which is close to the maximum recommended size of {max_rec_mb} "
+            "MB. This can be slow. See the documentation for possible work-arounds: "
+            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
+        )
+    old_limit = mx.metal.set_wired_limit(max_rec_size)
+    try:
+        yield None
+    finally:
+        if streams is not None:
+            for s in streams:
+                mx.synchronize(s)
+        else:
+            mx.synchronize()
+        mx.metal.set_wired_limit(old_limit)
+
+
 def _get_classes(config: dict):
     """
     Retrieve the model and model args classes based on the configuration.
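A minimal usage sketch of the new context manager (hypothetical, not from this commit; `generate` below applies `wired_limit` internally, so direct use only matters for custom decode loops). `model`, `prompt_tokens`, and `my_decode_loop` are stand-ins:

```python
import mlx.core as mx

from mlx_lm.utils import wired_limit

# Run a custom decode loop on an explicit stream; passing the stream to
# wired_limit makes the context manager synchronize it on exit, before
# the old wired limit is restored.
stream = mx.new_stream(mx.default_device())
with wired_limit(model, streams=[stream]):
    with mx.stream(stream):
        output = my_decode_loop(model, prompt_tokens)  # hypothetical helper
```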
@@ -330,48 +365,50 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
-    tic = time.perf_counter()
-    detokenizer.reset()
-
-    for n, (token, logprobs) in zip(
-        range(max_tokens),
-        generate_step(prompt_tokens, model, **kwargs),
-    ):
-        if n == 0:
-            prompt_time = time.perf_counter() - tic
-            tic = time.perf_counter()
-        if token == tokenizer.eos_token_id:
-            break
-        detokenizer.add_token(token)
-
-        if verbose:
-            if formatter:
-                # We have to finalize so that the prob corresponds to the last segment
-                detokenizer.finalize()
-                with mx.stream(mx.cpu):
-                    prob = mx.exp(logprobs[token]).item()
-                formatter(detokenizer.last_segment, prob)
-            else:
-                print(detokenizer.last_segment, end="", flush=True)
-
-    token_count = n + 1
-    detokenizer.finalize()
-
-    if verbose:
-        gen_time = time.perf_counter() - tic
-        print(detokenizer.last_segment, flush=True)
-        print("=" * 10)
-        if token_count == 0:
-            print("No tokens generated for this prompt")
-            return
-        prompt_tps = prompt_tokens.size / prompt_time
-        gen_tps = (token_count - 1) / gen_time
-        print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
-        print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
-        peak_mem = mx.metal.get_peak_memory() / 2**30
-        print(f"Peak memory: {peak_mem:.3f} GB")
-
-    return detokenizer.text
+    with wired_limit(model):
+        tic = time.perf_counter()
+        detokenizer.reset()
+        for n, (token, logprobs) in zip(
+            range(max_tokens),
+            generate_step(prompt_tokens, model, **kwargs),
+        ):
+            if n == 0:
+                prompt_time = time.perf_counter() - tic
+                tic = time.perf_counter()
+            if token == tokenizer.eos_token_id:
+                break
+            detokenizer.add_token(token)
+
+            if verbose:
+                if formatter:
+                    # We have to finalize so that the prob corresponds to the last segment
+                    detokenizer.finalize()
+                    with mx.stream(mx.cpu):
+                        prob = mx.exp(logprobs[token]).item()
+                    formatter(detokenizer.last_segment, prob)
+                else:
+                    print(detokenizer.last_segment, end="", flush=True)
+
+        token_count = n + 1
+        detokenizer.finalize()
+
+        if verbose:
+            gen_time = time.perf_counter() - tic
+            print(detokenizer.last_segment, flush=True)
+            print("=" * 10)
+            if token_count == 0:
+                print("No tokens generated for this prompt")
+                return
+            prompt_tps = prompt_tokens.size / prompt_time
+            gen_tps = (token_count - 1) / gen_time
+            print(
+                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
+            )
+            print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
+            peak_mem = mx.metal.get_peak_memory() / 2**30
+            print(f"Peak memory: {peak_mem:.3f} GB")
+
+        return detokenizer.text
 
 
 def load_config(model_path: Path) -> dict:
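After this change a plain `generate` call runs under the wired limit automatically, so no call-site changes are needed; a minimal example (the model name is illustrative):

```python
from mlx_lm import generate, load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
# generate() now wraps its decode loop in wired_limit(model), so large
# models get wired memory on macOS 15+ without extra code here.
text = generate(model, tokenizer, prompt="Hello,", max_tokens=64, verbose=True)
```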