Wire models in MLX LM (#1069)

* wired in MLX LM
* fix synch
* comment + nit
* version
* mlx lm version
* bump to 0.19.2

parent 8fe9539af7
commit 9f34fdbda4
@@ -248,3 +248,28 @@ model, tokenizer = load(
     tokenizer_config={"eos_token": "<|endoftext|>", "trust_remote_code": True},
 )
 ```
+
+### Large Models
+
+> [!NOTE]
+> This requires macOS 15.0 or higher to work.
+
+Models which are large relative to the total RAM available on the machine can
+be slow. `mlx-lm` will attempt to make them faster by wiring the memory
+occupied by the model and cache. This requires macOS 15 or higher to
+work.
+
+If you see the following warning message:
+
+> [WARNING] Generating with a model that requires ...
+
+then the model will likely be slow on the given machine. If the model fits in
+RAM then it can often be sped up by increasing the system wired memory limit.
+To increase the limit, set the following `sysctl`:
+
+```bash
+sudo sysctl iogpu.wired_limit_mb=N
+```
+
+The value `N` should be larger than the size of the model in megabytes but
+smaller than the memory size of the machine.
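For choosing `N` concretely, the sketch below is an illustration only (it is not part of this commit): it reuses the `tree_reduce` accounting that `wired_limit` uses later in this diff, assumes `mx.metal.device_info()` exposes a `memory_size` key, and uses a placeholder model name plus an arbitrary 4 GB headroom.

```python
# Sketch: estimate a value for `sudo sysctl iogpu.wired_limit_mb=N`.
# Assumptions: device_info() has a "memory_size" key; the model name is a
# placeholder; the 4 GB headroom is an arbitrary choice.
import mlx.core as mx
from mlx.utils import tree_reduce
from mlx_lm import load

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")

# Bytes held by the model's parameters (same accounting as wired_limit).
model_bytes = tree_reduce(
    lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
)
model_mb = model_bytes // 2**20
total_mb = mx.metal.device_info()["memory_size"] // 2**20

# N should exceed the model size but stay below the machine's memory,
# leaving room for the KV cache and the rest of the system.
suggested = min(model_mb + 4096, total_mb - 4096)
print(f"model: {model_mb} MB of {total_mb} MB RAM")
print(f"sudo sysctl iogpu.wired_limit_mb={suggested}")
```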
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.19.1"
+__version__ = "0.19.3"
@@ -56,7 +56,7 @@ def main():
         tokenizer_config={"trust_remote_code": True},
     )
 
-    print(f"[INFO] Starting chat sessiong with {args.model}. To exit, enter 'q'.")
+    print(f"[INFO] Starting chat session with {args.model}. To exit, enter 'q'.")
     prompt_cache = make_prompt_cache(model, args.max_kv_size)
     while True:
         query = input(">> ")
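The prompt cache created here is what lets the chat session avoid re-processing earlier turns. A rough sketch of that pattern outside of `chat.py` follows; it is illustrative only: the model name is a placeholder, and it assumes `generate` forwards a `prompt_cache` keyword on to `generate_step` in this version of `mlx-lm`.

```python
# Sketch of the chat-loop pattern: one prompt cache kept alive across turns.
# Assumes generate() forwards prompt_cache to generate_step(); the model
# name is a placeholder.
from mlx_lm import load, generate
from mlx_lm.models.cache import make_prompt_cache

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt_cache = make_prompt_cache(model)  # conversation history lives here

while True:
    query = input(">> ")
    if query == "q":
        break
    # Only the new turn is templated; earlier turns are already in the cache.
    prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": query}],
        tokenize=False,
        add_generation_prompt=True,
    )
    response = generate(
        model, tokenizer, prompt, max_tokens=512, prompt_cache=prompt_cache
    )
    print(response)
```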
@@ -1,4 +1,4 @@
-mlx>=0.17.0
+mlx>=0.19.2
 numpy
 transformers[sentencepiece]>=4.39.3
 protobuf
@@ -1,5 +1,6 @@
 # Copyright © 2023-2024 Apple Inc.
 
+import contextlib
 import copy
 import glob
 import importlib
@@ -14,7 +15,7 @@ from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Type,
 import mlx.core as mx
 import mlx.nn as nn
 from huggingface_hub import snapshot_download
-from mlx.utils import tree_flatten
+from mlx.utils import tree_flatten, tree_reduce
 from transformers import PreTrainedTokenizer
 
 # Local imports
@@ -39,6 +40,40 @@ class ModelNotFoundError(Exception):
         super().__init__(self.message)
 
 
+@contextlib.contextmanager
+def wired_limit(model: nn.Module, streams: Optional[List[mx.Stream]] = None):
+    """
+    A context manager to temporarily change the wired limit.
+
+    Note, the wired limit should not be changed during an async eval. If an
+    async eval could be running, pass in the streams to synchronize with prior
+    to exiting the context manager.
+    """
+    model_bytes = tree_reduce(
+        lambda acc, x: acc + x.nbytes if isinstance(x, mx.array) else acc, model, 0
+    )
+    max_rec_size = mx.metal.device_info()["max_recommended_working_set_size"]
+    if model_bytes > 0.9 * max_rec_size:
+        model_mb = model_bytes // 2**20
+        max_rec_mb = max_rec_size // 2**20
+        print(
+            f"[WARNING] Generating with a model that requires {model_mb} MB "
+            f"which is close to the maximum recommended size of {max_rec_mb} "
+            "MB. This can be slow. See the documentation for possible work-arounds: "
+            "https://github.com/ml-explore/mlx-examples/tree/main/llms#large-models"
+        )
+    old_limit = mx.metal.set_wired_limit(max_rec_size)
+    try:
+        yield None
+    finally:
+        if streams is not None:
+            for s in streams:
+                mx.synchronize(s)
+        else:
+            mx.synchronize()
+        mx.metal.set_wired_limit(old_limit)
+
+
 def _get_classes(config: dict):
     """
     Retrieve the model and model args classes based on the configuration.
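`generate` (next hunk) wraps its decoding loop in this context manager; the same pattern works for a hand-rolled loop over `generate_step`. A minimal sketch, not part of this commit and using a placeholder model name:

```python
# Sketch: wrap a manual decoding loop in wired_limit so the model's memory
# stays wired while tokens are produced; the previous limit is restored on exit.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.utils import generate_step, wired_limit

model, tokenizer = load("mlx-community/Mistral-7B-Instruct-v0.3-4bit")
prompt = mx.array(tokenizer.encode("Write a haiku about the ocean."))
detokenizer = tokenizer.detokenizer

detokenizer.reset()
with wired_limit(model):
    for _, (token, _) in zip(range(64), generate_step(prompt, model)):
        if token == tokenizer.eos_token_id:
            break
        detokenizer.add_token(token)
    detokenizer.finalize()

print(detokenizer.text)
```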
@@ -330,9 +365,9 @@ def generate(
     prompt_tokens = mx.array(tokenizer.encode(prompt))
     detokenizer = tokenizer.detokenizer
 
+    with wired_limit(model):
         tic = time.perf_counter()
         detokenizer.reset()
 
         for n, (token, logprobs) in zip(
             range(max_tokens),
             generate_step(prompt_tokens, model, **kwargs),
@@ -366,7 +401,9 @@ def generate(
                 return
             prompt_tps = prompt_tokens.size / prompt_time
             gen_tps = (token_count - 1) / gen_time
-            print(f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec")
+            print(
+                f"Prompt: {prompt_tokens.size} tokens, {prompt_tps:.3f} tokens-per-sec"
+            )
             print(f"Generation: {token_count} tokens, {gen_tps:.3f} tokens-per-sec")
             peak_mem = mx.metal.get_peak_memory() / 2**30
             print(f"Peak memory: {peak_mem:.3f} GB")