mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
128 lines
3.5 KiB
Python
128 lines
3.5 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
"""
|
|
Run with:
|
|
|
|
```
|
|
mlx.launch \
|
|
--hostfile /path/to/hosts.txt \
|
|
--backend mpi \
|
|
/path/to/pipeline_generate.py \
|
|
--prompt "hello world"
|
|
```
|
|
|
|
Make sure you can run MLX over MPI on two hosts. For more information see the
|
|
documentation:
|
|
|
|
https://ml-explore.github.io/mlx/build/html/usage/distributed.html
|
|
"""
|
|
|
|
import argparse
|
|
import json
|
|
from pathlib import Path
|
|
|
|
import mlx.core as mx
|
|
from huggingface_hub import snapshot_download
|
|
from mlx.utils import tree_flatten
|
|
from mlx_lm import load, stream_generate
|
|
from mlx_lm.utils import load_model, load_tokenizer
|
|
|
|
|
|
def download(repo: str, allow_patterns: list[str]) -> Path:
    """Fetch the files of ``repo`` matching ``allow_patterns``.

    The work is delegated to ``snapshot_download``; whatever location it
    returns is wrapped in a ``Path`` and handed back to the caller.
    """
    snapshot_dir = snapshot_download(repo, allow_patterns=allow_patterns)
    return Path(snapshot_dir)
|
|
|
|
|
|
def shard_and_load(repo):
    """Download, pipeline-shard and load the model and tokenizer for this process.

    Only the weight files referenced by the parameters that remain on the
    model after ``model.model.pipeline(group)`` are downloaded.

    Args:
        repo: Repository identifier forwarded to ``download``.

    Returns:
        A ``(model, tokenizer)`` tuple. The model parameters have been
        evaluated before returning.
    """
    # Get model path with everything but weight safetensors
    model_path = download(
        repo,
        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
    )

    # Lazy load and shard model to figure out
    # which weights we need
    model, _ = load_model(model_path, lazy=True, strict=False)

    group = mx.distributed.init(backend="mpi")
    model.model.pipeline(group)

    # Figure out which files we need for the local shard
    with open(model_path / "model.safetensors.index.json", "r") as fid:
        weight_index = json.load(fid)["weight_map"]

    local_files = {weight_index[k] for k, _ in tree_flatten(model.parameters())}

    # Download weights for local shard. ``download`` is annotated as taking a
    # list of patterns, so pass a (deterministically ordered) list.
    download(repo, allow_patterns=sorted(local_files))

    # Load and shard the model, and load the weights
    tokenizer = load_tokenizer(model_path)
    model, _ = load_model(model_path, lazy=True, strict=False)
    model.model.pipeline(group)
    mx.eval(model.parameters())

    # Synchronize processes before generation to avoid timeout if downloading
    # model for the first time.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
    return model, tokenizer
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line interface for the example.
    parser = argparse.ArgumentParser(description="LLM pipelined inference example")
    parser.add_argument(
        "--model",
        default="mlx-community/DeepSeek-R1-3bit",
        help="HF repo or path to local model.",
    )
    parser.add_argument(
        "--prompt",
        "-p",
        default="Write a quicksort in C++.",
        # The script never reads from stdin, so the help no longer claims it.
        help="Message to be processed by the model",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=256,
        help="Maximum number of tokens to generate",
    )
    args = parser.parse_args()

    group = mx.distributed.init(backend="mpi")
    rank = group.rank()

    def rprint(*print_args, **print_kwargs):
        """Forward to ``print`` on rank 0 only; do nothing on other ranks."""
        # Parameter names deliberately differ from the parsed ``args`` above
        # so the CLI namespace is not shadowed inside this helper.
        if rank == 0:
            print(*print_args, **print_kwargs)

    model, tokenizer = shard_and_load(args.model)

    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    # Stays ``None`` if the generator yields nothing, so the summary below
    # cannot fail with an unbound name.
    response = None
    for response in stream_generate(
        model, tokenizer, prompt, max_tokens=args.max_tokens
    ):
        rprint(response.text, end="", flush=True)

    rprint()
    rprint("=" * 10)
    if response is not None:
        # Statistics are read from the last response yielded by the generator.
        rprint(
            f"Prompt: {response.prompt_tokens} tokens, "
            f"{response.prompt_tps:.3f} tokens-per-sec"
        )
        rprint(
            f"Generation: {response.generation_tokens} tokens, "
            f"{response.generation_tps:.3f} tokens-per-sec"
        )
        rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|