only download local shard (#1240)

Awni Hannun 2025-02-02 13:58:44 -08:00 committed by GitHub
parent e8afb59de4
commit 9c2ef38d4d
4 changed files with 159 additions and 65 deletions


@@ -4,10 +4,11 @@
 Run with:
 
 ```
-/path/to/mpirun \
- -np 2 \
+mlx.launch \
  --hostfile /path/to/hosts.txt \
-python /path/to/pipeline_generate.py --prompt "hello world"
+ --backend mpi \
+ /path/to/pipeline_generate.py \
+ --prompt "hello world"
 ```
 
 Make sure you can run MLX over MPI on two hosts. For more information see the
@@ -17,62 +18,106 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
 """
 import argparse
+import json
+from pathlib import Path
 
 import mlx.core as mx
+from huggingface_hub import snapshot_download
+from mlx.utils import tree_flatten
 from mlx_lm import load, stream_generate
+from mlx_lm.utils import load_model, load_tokenizer
 
-parser = argparse.ArgumentParser(description="LLM pipelined inference example")
-parser.add_argument(
-    "--model",
-    default="mlx-community/DeepSeek-R1-3bit",
-    help="HF repo or path to local model.",
-)
-parser.add_argument(
-    "--prompt",
-    "-p",
-    default="Write a quicksort in C++.",
-    help="Message to be processed by the model ('-' reads from stdin)",
-)
-parser.add_argument(
-    "--max-tokens",
-    "-m",
-    type=int,
-    default=256,
-    help="Maximum number of tokens to generate",
-)
-args = parser.parse_args()
-
-model, tokenizer = load(args.model, lazy=True)
-
-messages = [{"role": "user", "content": args.prompt}]
-prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
-
-group = mx.distributed.init()
-rank = group.rank()
-
-model.model.pipeline(group)
-mx.eval(model.parameters())
-
-# Synchronize processes before generation to avoid timeout if downloading
-# model for the first time.
-mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
-
-
-def rprint(*args, **kwargs):
-    if rank == 0:
-        print(*args, **kwargs)
-
-
-for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
-    rprint(response.text, end="", flush=True)
-
-rprint()
-rprint("=" * 10)
-rprint(
-    f"Prompt: {response.prompt_tokens} tokens, "
-    f"{response.prompt_tps:.3f} tokens-per-sec"
-)
-rprint(
-    f"Generation: {response.generation_tokens} tokens, "
-    f"{response.generation_tps:.3f} tokens-per-sec"
-)
-rprint(f"Peak memory: {response.peak_memory:.3f} GB")
+
+def download(repo: str, allow_patterns: list[str]) -> Path:
+    return Path(
+        snapshot_download(
+            repo,
+            allow_patterns=allow_patterns,
+        )
+    )
+
+
+def shard_and_load(repo):
+    # Get model path with everything but weight safetensors
+    model_path = download(
+        args.model,
+        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
+    )
+
+    # Lazy load and shard model
+    model, _ = load_model(model_path, lazy=True, strict=False)
+
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+    model.model.pipeline(group)
+
+    # Figure out which files we need for the local shard
+    with open(model_path / "model.safetensors.index.json", "r") as fid:
+        weight_index = json.load(fid)["weight_map"]
+
+    local_files = set()
+    for k, _ in tree_flatten(model.parameters()):
+        local_files.add(weight_index[k])
+
+    # Download weights for local shard
+    download(args.model, allow_patterns=local_files)
+
+    tokenizer = load_tokenizer(model_path)
+    model, _ = load_model(model_path)
+
+    # Synchronize processes before generation to avoid timeout if downloading
+    # model for the first time.
+    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
+    return model, tokenizer
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="LLM pipelined inference example")
+    parser.add_argument(
+        "--model",
+        default="mlx-community/DeepSeek-R1-3bit",
+        help="HF repo or path to local model.",
+    )
+    parser.add_argument(
+        "--prompt",
+        "-p",
+        default="Write a quicksort in C++.",
+        help="Message to be processed by the model ('-' reads from stdin)",
+    )
+    parser.add_argument(
+        "--max-tokens",
+        "-m",
+        type=int,
+        default=256,
+        help="Maximum number of tokens to generate",
+    )
+    args = parser.parse_args()
+
+    group = mx.distributed.init(backend="mpi")
+    rank = group.rank()
+
+    def rprint(*args, **kwargs):
+        if rank == 0:
+            print(*args, **kwargs)
+
+    model, tokenizer = shard_and_load(args.model)
+
+    messages = [{"role": "user", "content": args.prompt}]
+    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+    for response in stream_generate(
+        model, tokenizer, prompt, max_tokens=args.max_tokens
+    ):
+        rprint(response.text, end="", flush=True)
+
+    rprint()
+    rprint("=" * 10)
+    rprint(
+        f"Prompt: {response.prompt_tokens} tokens, "
+        f"{response.prompt_tps:.3f} tokens-per-sec"
+    )
+    rprint(
+        f"Generation: {response.generation_tokens} tokens, "
+        f"{response.generation_tps:.3f} tokens-per-sec"
+    )
+    rprint(f"Peak memory: {response.peak_memory:.3f} GB")


@@ -364,8 +364,29 @@ class DeepseekV2Model(nn.Module):
             DeepseekV2DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pipeline_rank = 0
+        self.pipeline_size = 1
+
+    def pipeline(self, group):
+        # Split layers in reverse so rank=0 gets the last layers and
+        # rank=pipeline_size-1 gets the first
+        self.pipeline_rank = group.rank()
+        self.pipeline_size = group.size()
+        layers_per_rank = (
+            len(self.layers) + self.pipeline_size - 1
+        ) // self.pipeline_size
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.num_layers = layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
 
     def __call__(
         self,
         x: mx.array,
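
To make the reverse split concrete, here is a toy walk-through of the arithmetic in `pipeline()`; the 8 layers and 2 ranks are made-up numbers for illustration:

```
# Rank 0 keeps the last block of layers and the highest rank keeps the first,
# matching the direction activations flow through the pipeline.
num_hidden_layers, pipeline_size = 8, 2  # hypothetical sizes

for pipeline_rank in range(pipeline_size):
    layers_per_rank = (num_hidden_layers + pipeline_size - 1) // pipeline_size
    start_idx = (pipeline_size - pipeline_rank - 1) * layers_per_rank
    end_idx = start_idx + layers_per_rank
    print(f"rank {pipeline_rank}: layers [{start_idx}, {end_idx})")
# rank 0: layers [4, 8)
# rank 1: layers [0, 4)
```

Replacing the leading layers with `None` instead of slicing them away keeps the surviving layers at their original indices, so their parameter names still line up with the checkpoint's weight index used by the example script.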
@@ -374,14 +395,31 @@ class DeepseekV2Model(nn.Module):
     ) -> mx.array:
         h = self.embed_tokens(x)
 
+        pipeline_rank = self.pipeline_rank
+        pipeline_size = self.pipeline_size
+        # Hack to avoid time-outs during prompt-processing
+        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
+
         if mask is None:
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        # Receive from the previous process in the pipeline
+        if pipeline_rank < pipeline_size - 1:
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
+
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
+
+        # Send to the next process in the pipeline
+        if pipeline_rank != 0:
+            h = mx.distributed.send(
+                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
+            )
+
+        # Broadcast h while keeping it in the graph
+        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
 
         return self.norm(h)
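
The closing `all_gather` doubles as a broadcast of rank 0's activations (rank 0 runs the final layers) while keeping the send/recv ops in the compute graph so they are actually evaluated. A small sketch of the slice trick, with illustrative shapes:

```
import mlx.core as mx

group = mx.distributed.init()
h = mx.ones((2, 4)) * group.rank()       # stand-in for this rank's activations
# all_gather stacks each rank's contribution along the first axis in rank
# order, so the first h.shape[0] rows are rank 0's tensor on every process.
gathered = mx.distributed.all_gather(h)  # shape: (2 * group.size(), 4)
h = gathered[: 2]
print(h)                                 # identical on every rank
```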
@@ -418,4 +456,4 @@ class Model(nn.Module):
 
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]


@@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
             DeepseekV3DecoderLayer(config, idx)
             for idx in range(config.num_hidden_layers)
         ]
+        self.start_idx = 0
+        self.end_idx = len(self.layers)
+        self.num_layers = self.end_idx
+
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.pipeline_rank = 0
         self.pipeline_size = 1
@@ -394,7 +398,11 @@ class DeepseekV3Model(nn.Module):
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
         start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
-        self.layers = self.layers[start : start + layers_per_rank]
+        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.end_idx = self.start_idx + layers_per_rank
+        self.num_layers = layers_per_rank
+        self.layers = self.layers[: self.end_idx]
+        self.layers[: self.start_idx] = [None] * self.start_idx
 
     def __call__(
         self,
@@ -412,15 +420,15 @@ class DeepseekV3Model(nn.Module):
             mask = create_attention_mask(h, cache)
 
         if cache is None:
-            cache = [None] * len(self.layers)
+            cache = [None] * self.num_layers
 
         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
             h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
 
-        for layer, c in zip(self.layers, cache):
-            h = layer(h, mask, c)
+        for i in range(self.num_layers):
+            h = self.layers[self.start_idx + i](h, mask, cache[i])
 
         # Send to the next process in the pipeline
         if pipeline_rank != 0:
@@ -468,4 +476,4 @@ class Model(nn.Module):
 
     @property
     def layers(self):
-        return self.model.layers
+        return self.model.layers[self.model.start_idx : self.model.end_idx]


@@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
 def load_model(
     model_path: Path,
     lazy: bool = False,
+    strict: bool = True,
     model_config: dict = {},
     get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
 ) -> nn.Module:
@@ -638,6 +639,8 @@ def load_model(
         lazy (bool): If False eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
+        strict (bool): Whether or not to raise an exception if weights don't
+            match. Default: ``True``
         model_config (dict, optional): Optional configuration parameters for the
             model. Defaults to an empty dictionary.
         get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -660,7 +663,7 @@ def load_model(
 
     # Try weight for back-compat
     weight_files = glob.glob(str(model_path / "weight*.safetensors"))
-    if not weight_files:
+    if not weight_files and strict:
         logging.error(f"No safetensors found in {model_path}")
         raise FileNotFoundError(f"No safetensors found in {model_path}")
@@ -694,7 +697,7 @@ def load_model(
         class_predicate=class_predicate,
     )
 
-    model.load_weights(list(weights.items()))
+    model.load_weights(list(weights.items()), strict=strict)
 
     if not lazy:
         mx.eval(model.parameters())
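
The new `strict` flag is what lets the example script instantiate the model from its config before any safetensors exist locally, and then load a checkpoint that only contains this rank's shard without tripping the missing-weights checks. A hedged usage sketch mirroring the example's call (the local path is hypothetical):

```
from pathlib import Path

from mlx_lm.utils import load_model

# With strict=False, absent weight files and unmatched parameters no longer
# raise, so the model structure can be built first and the weights filled in
# once the local shard has been downloaded.
model_path = Path("path/to/partially/downloaded/model")  # hypothetical
model, _ = load_model(model_path, lazy=True, strict=False)
```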