mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
only download local shard (#1240)
This commit is contained in:
parent
e8afb59de4
commit
9c2ef38d4d
@ -4,10 +4,11 @@
|
|||||||
Run with:
|
Run with:
|
||||||
|
|
||||||
```
|
```
|
||||||
/path/to/mpirun \
|
mlx.launch \
|
||||||
-np 2 \
|
|
||||||
--hostfile /path/to/hosts.txt \
|
--hostfile /path/to/hosts.txt \
|
||||||
python /path/to/pipeline_generate.py --prompt "hello world"
|
--backend mpi \
|
||||||
|
/path/to/pipeline_generate.py \
|
||||||
|
--prompt "hello world"
|
||||||
```
|
```
|
||||||
|
|
||||||
Make sure you can run MLX over MPI on two hosts. For more information see the
|
Make sure you can run MLX over MPI on two hosts. For more information see the
|
||||||
@ -17,62 +18,106 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import json
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
from huggingface_hub import snapshot_download
|
||||||
|
from mlx.utils import tree_flatten
|
||||||
from mlx_lm import load, stream_generate
|
from mlx_lm import load, stream_generate
|
||||||
|
from mlx_lm.utils import load_model, load_tokenizer
|
||||||
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
|
||||||
parser.add_argument(
|
|
||||||
"--model",
|
|
||||||
default="mlx-community/DeepSeek-R1-3bit",
|
|
||||||
help="HF repo or path to local model.",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--prompt",
|
|
||||||
"-p",
|
|
||||||
default="Write a quicksort in C++.",
|
|
||||||
help="Message to be processed by the model ('-' reads from stdin)",
|
|
||||||
)
|
|
||||||
parser.add_argument(
|
|
||||||
"--max-tokens",
|
|
||||||
"-m",
|
|
||||||
type=int,
|
|
||||||
default=256,
|
|
||||||
help="Maximum number of tokens to generate",
|
|
||||||
)
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
model, tokenizer = load(args.model, lazy=True)
|
|
||||||
|
|
||||||
messages = [{"role": "user", "content": args.prompt}]
|
|
||||||
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
|
||||||
|
|
||||||
group = mx.distributed.init()
|
|
||||||
rank = group.rank()
|
|
||||||
model.model.pipeline(group)
|
|
||||||
mx.eval(model.parameters())
|
|
||||||
|
|
||||||
# Synchronize processes before generation to avoid timeout if downloading
|
|
||||||
# model for the first time.
|
|
||||||
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
|
||||||
|
|
||||||
|
|
||||||
def rprint(*args, **kwargs):
|
def download(repo: str, allow_patterns: list[str]) -> Path:
|
||||||
if rank == 0:
|
return Path(
|
||||||
print(*args, **kwargs)
|
snapshot_download(
|
||||||
|
repo,
|
||||||
|
allow_patterns=allow_patterns,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
|
def shard_and_load(repo):
|
||||||
rprint(response.text, end="", flush=True)
|
# Get model path with everything but weight safetensors
|
||||||
|
model_path = download(
|
||||||
|
args.model,
|
||||||
|
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
||||||
|
)
|
||||||
|
|
||||||
rprint()
|
# Lazy load and shard model
|
||||||
rprint("=" * 10)
|
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||||
rprint(
|
|
||||||
f"Prompt: {response.prompt_tokens} tokens, "
|
group = mx.distributed.init(backend="mpi")
|
||||||
f"{response.prompt_tps:.3f} tokens-per-sec"
|
rank = group.rank()
|
||||||
)
|
model.model.pipeline(group)
|
||||||
rprint(
|
|
||||||
f"Generation: {response.generation_tokens} tokens, "
|
# Figure out which files we need for the local shard
|
||||||
f"{response.generation_tps:.3f} tokens-per-sec"
|
with open(model_path / "model.safetensors.index.json", "r") as fid:
|
||||||
)
|
weight_index = json.load(fid)["weight_map"]
|
||||||
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
|
||||||
|
local_files = set()
|
||||||
|
for k, _ in tree_flatten(model.parameters()):
|
||||||
|
local_files.add(weight_index[k])
|
||||||
|
|
||||||
|
# Download weights for local shard
|
||||||
|
download(args.model, allow_patterns=local_files)
|
||||||
|
|
||||||
|
tokenizer = load_tokenizer(model_path)
|
||||||
|
model, _ = load_model(model_path)
|
||||||
|
|
||||||
|
# Synchronize processes before generation to avoid timeout if downloading
|
||||||
|
# model for the first time.
|
||||||
|
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
|
||||||
|
return model, tokenizer
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
parser = argparse.ArgumentParser(description="LLM pipelined inference example")
|
||||||
|
parser.add_argument(
|
||||||
|
"--model",
|
||||||
|
default="mlx-community/DeepSeek-R1-3bit",
|
||||||
|
help="HF repo or path to local model.",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--prompt",
|
||||||
|
"-p",
|
||||||
|
default="Write a quicksort in C++.",
|
||||||
|
help="Message to be processed by the model ('-' reads from stdin)",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--max-tokens",
|
||||||
|
"-m",
|
||||||
|
type=int,
|
||||||
|
default=256,
|
||||||
|
help="Maximum number of tokens to generate",
|
||||||
|
)
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
group = mx.distributed.init(backend="mpi")
|
||||||
|
rank = group.rank()
|
||||||
|
|
||||||
|
def rprint(*args, **kwargs):
|
||||||
|
if rank == 0:
|
||||||
|
print(*args, **kwargs)
|
||||||
|
|
||||||
|
model, tokenizer = shard_and_load(args.model)
|
||||||
|
|
||||||
|
messages = [{"role": "user", "content": args.prompt}]
|
||||||
|
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
|
||||||
|
|
||||||
|
for response in stream_generate(
|
||||||
|
model, tokenizer, prompt, max_tokens=args.max_tokens
|
||||||
|
):
|
||||||
|
rprint(response.text, end="", flush=True)
|
||||||
|
|
||||||
|
rprint()
|
||||||
|
rprint("=" * 10)
|
||||||
|
rprint(
|
||||||
|
f"Prompt: {response.prompt_tokens} tokens, "
|
||||||
|
f"{response.prompt_tps:.3f} tokens-per-sec"
|
||||||
|
)
|
||||||
|
rprint(
|
||||||
|
f"Generation: {response.generation_tokens} tokens, "
|
||||||
|
f"{response.generation_tps:.3f} tokens-per-sec"
|
||||||
|
)
|
||||||
|
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
|
||||||
|
@ -364,8 +364,29 @@ class DeepseekV2Model(nn.Module):
|
|||||||
DeepseekV2DecoderLayer(config, idx)
|
DeepseekV2DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
self.start_idx = 0
|
||||||
|
self.end_idx = len(self.layers)
|
||||||
|
self.num_layers = self.end_idx
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
|
|
||||||
|
self.pipeline_rank = 0
|
||||||
|
self.pipeline_size = 1
|
||||||
|
|
||||||
|
def pipeline(self, group):
|
||||||
|
# Split layers in reverse so rank=0 gets the last layers and
|
||||||
|
# rank=pipeline_size-1 gets the first
|
||||||
|
self.pipeline_rank = group.rank()
|
||||||
|
self.pipeline_size = group.size()
|
||||||
|
layers_per_rank = (
|
||||||
|
len(self.layers) + self.pipeline_size - 1
|
||||||
|
) // self.pipeline_size
|
||||||
|
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
|
self.end_idx = self.start_idx + layers_per_rank
|
||||||
|
self.num_layers = layers_per_rank
|
||||||
|
self.layers = self.layers[: self.end_idx]
|
||||||
|
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
x: mx.array,
|
x: mx.array,
|
||||||
@ -374,14 +395,31 @@ class DeepseekV2Model(nn.Module):
|
|||||||
) -> mx.array:
|
) -> mx.array:
|
||||||
h = self.embed_tokens(x)
|
h = self.embed_tokens(x)
|
||||||
|
|
||||||
|
pipeline_rank = self.pipeline_rank
|
||||||
|
pipeline_size = self.pipeline_size
|
||||||
|
# Hack to avoid time-outs during prompt-processing
|
||||||
|
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||||
if mask is None:
|
if mask is None:
|
||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * self.num_layers
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
# Receive from the previous process in the pipeline
|
||||||
h = layer(h, mask, c)
|
if pipeline_rank < pipeline_size - 1:
|
||||||
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||||
|
|
||||||
|
for i in range(self.num_layers):
|
||||||
|
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||||
|
|
||||||
|
# Send to the next process in the pipeline
|
||||||
|
if pipeline_rank != 0:
|
||||||
|
h = mx.distributed.send(
|
||||||
|
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||||
|
)
|
||||||
|
|
||||||
|
# Broadcast h while keeping it in the graph
|
||||||
|
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||||
|
|
||||||
return self.norm(h)
|
return self.norm(h)
|
||||||
|
|
||||||
@ -418,4 +456,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||||
|
@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
|
|||||||
DeepseekV3DecoderLayer(config, idx)
|
DeepseekV3DecoderLayer(config, idx)
|
||||||
for idx in range(config.num_hidden_layers)
|
for idx in range(config.num_hidden_layers)
|
||||||
]
|
]
|
||||||
|
self.start_idx = 0
|
||||||
|
self.end_idx = len(self.layers)
|
||||||
|
self.num_layers = self.end_idx
|
||||||
|
|
||||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||||
self.pipeline_rank = 0
|
self.pipeline_rank = 0
|
||||||
self.pipeline_size = 1
|
self.pipeline_size = 1
|
||||||
@ -394,7 +398,11 @@ class DeepseekV3Model(nn.Module):
|
|||||||
len(self.layers) + self.pipeline_size - 1
|
len(self.layers) + self.pipeline_size - 1
|
||||||
) // self.pipeline_size
|
) // self.pipeline_size
|
||||||
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
self.layers = self.layers[start : start + layers_per_rank]
|
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||||
|
self.end_idx = self.start_idx + layers_per_rank
|
||||||
|
self.num_layers = layers_per_rank
|
||||||
|
self.layers = self.layers[: self.end_idx]
|
||||||
|
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -412,15 +420,15 @@ class DeepseekV3Model(nn.Module):
|
|||||||
mask = create_attention_mask(h, cache)
|
mask = create_attention_mask(h, cache)
|
||||||
|
|
||||||
if cache is None:
|
if cache is None:
|
||||||
cache = [None] * len(self.layers)
|
cache = [None] * self.num_layers
|
||||||
|
|
||||||
# Receive from the previous process in the pipeline
|
# Receive from the previous process in the pipeline
|
||||||
|
|
||||||
if pipeline_rank < pipeline_size - 1:
|
if pipeline_rank < pipeline_size - 1:
|
||||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||||
|
|
||||||
for layer, c in zip(self.layers, cache):
|
for i in range(self.num_layers):
|
||||||
h = layer(h, mask, c)
|
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||||
|
|
||||||
# Send to the next process in the pipeline
|
# Send to the next process in the pipeline
|
||||||
if pipeline_rank != 0:
|
if pipeline_rank != 0:
|
||||||
@ -468,4 +476,4 @@ class Model(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def layers(self):
|
def layers(self):
|
||||||
return self.model.layers
|
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||||
|
@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
|
|||||||
def load_model(
|
def load_model(
|
||||||
model_path: Path,
|
model_path: Path,
|
||||||
lazy: bool = False,
|
lazy: bool = False,
|
||||||
|
strict: bool = True,
|
||||||
model_config: dict = {},
|
model_config: dict = {},
|
||||||
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
|
||||||
) -> nn.Module:
|
) -> nn.Module:
|
||||||
@ -638,6 +639,8 @@ def load_model(
|
|||||||
lazy (bool): If False eval the model parameters to make sure they are
|
lazy (bool): If False eval the model parameters to make sure they are
|
||||||
loaded in memory before returning, otherwise they will be loaded
|
loaded in memory before returning, otherwise they will be loaded
|
||||||
when needed. Default: ``False``
|
when needed. Default: ``False``
|
||||||
|
strict (bool): Whether or not to raise an exception if weights don't
|
||||||
|
match. Default: ``True``
|
||||||
model_config (dict, optional): Optional configuration parameters for the
|
model_config (dict, optional): Optional configuration parameters for the
|
||||||
model. Defaults to an empty dictionary.
|
model. Defaults to an empty dictionary.
|
||||||
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
|
||||||
@ -660,7 +663,7 @@ def load_model(
|
|||||||
# Try weight for back-compat
|
# Try weight for back-compat
|
||||||
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
weight_files = glob.glob(str(model_path / "weight*.safetensors"))
|
||||||
|
|
||||||
if not weight_files:
|
if not weight_files and strict:
|
||||||
logging.error(f"No safetensors found in {model_path}")
|
logging.error(f"No safetensors found in {model_path}")
|
||||||
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
raise FileNotFoundError(f"No safetensors found in {model_path}")
|
||||||
|
|
||||||
@ -694,7 +697,7 @@ def load_model(
|
|||||||
class_predicate=class_predicate,
|
class_predicate=class_predicate,
|
||||||
)
|
)
|
||||||
|
|
||||||
model.load_weights(list(weights.items()))
|
model.load_weights(list(weights.items()), strict=strict)
|
||||||
|
|
||||||
if not lazy:
|
if not lazy:
|
||||||
mx.eval(model.parameters())
|
mx.eval(model.parameters())
|
||||||
|
Loading…
Reference in New Issue
Block a user