From 9c2ef38d4d7a6c6afc1f9c90cb26cadb95f7343e Mon Sep 17 00:00:00 2001 From: Awni Hannun Date: Sun, 2 Feb 2025 13:58:44 -0800 Subject: [PATCH] only download local shard (#1240) --- llms/mlx_lm/examples/pipeline_generate.py | 153 ++++++++++++++-------- llms/mlx_lm/models/deepseek_v2.py | 46 ++++++- llms/mlx_lm/models/deepseek_v3.py | 18 ++- llms/mlx_lm/utils.py | 7 +- 4 files changed, 159 insertions(+), 65 deletions(-) diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py index 2970b986..d170405a 100644 --- a/llms/mlx_lm/examples/pipeline_generate.py +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -4,10 +4,11 @@ Run with: ``` -/path/to/mpirun \ - -np 2 \ +mlx.launch \ --hostfile /path/to/hosts.txt \ - python /path/to/pipeline_generate.py --prompt "hello world" + --backend mpi \ + /path/to/pipeline_generate.py \ + --prompt "hello world" ``` Make sure you can run MLX over MPI on two hosts. For more information see the @@ -17,62 +18,106 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html). """ import argparse +import json +from pathlib import Path import mlx.core as mx +from huggingface_hub import snapshot_download +from mlx.utils import tree_flatten from mlx_lm import load, stream_generate - -parser = argparse.ArgumentParser(description="LLM pipelined inference example") -parser.add_argument( - "--model", - default="mlx-community/DeepSeek-R1-3bit", - help="HF repo or path to local model.", -) -parser.add_argument( - "--prompt", - "-p", - default="Write a quicksort in C++.", - help="Message to be processed by the model ('-' reads from stdin)", -) -parser.add_argument( - "--max-tokens", - "-m", - type=int, - default=256, - help="Maximum number of tokens to generate", -) -args = parser.parse_args() - -model, tokenizer = load(args.model, lazy=True) - -messages = [{"role": "user", "content": args.prompt}] -prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) - -group = mx.distributed.init() -rank = group.rank() -model.model.pipeline(group) -mx.eval(model.parameters()) - -# Synchronize processes before generation to avoid timeout if downloading -# model for the first time. 
-mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) +from mlx_lm.utils import load_model, load_tokenizer -def rprint(*args, **kwargs): - if rank == 0: - print(*args, **kwargs) +def download(repo: str, allow_patterns: list[str]) -> Path: + return Path( + snapshot_download( + repo, + allow_patterns=allow_patterns, + ) + ) -for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens): - rprint(response.text, end="", flush=True) +def shard_and_load(repo): + # Get model path with everything but weight safetensors + model_path = download( + args.model, + allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"], + ) -rprint() -rprint("=" * 10) -rprint( - f"Prompt: {response.prompt_tokens} tokens, " - f"{response.prompt_tps:.3f} tokens-per-sec" -) -rprint( - f"Generation: {response.generation_tokens} tokens, " - f"{response.generation_tps:.3f} tokens-per-sec" -) -rprint(f"Peak memory: {response.peak_memory:.3f} GB") + # Lazy load and shard model + model, _ = load_model(model_path, lazy=True, strict=False) + + group = mx.distributed.init(backend="mpi") + rank = group.rank() + model.model.pipeline(group) + + # Figure out which files we need for the local shard + with open(model_path / "model.safetensors.index.json", "r") as fid: + weight_index = json.load(fid)["weight_map"] + + local_files = set() + for k, _ in tree_flatten(model.parameters()): + local_files.add(weight_index[k]) + + # Download weights for local shard + download(args.model, allow_patterns=local_files) + + tokenizer = load_tokenizer(model_path) + model, _ = load_model(model_path) + + # Synchronize processes before generation to avoid timeout if downloading + # model for the first time. + mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu)) + return model, tokenizer + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="LLM pipelined inference example") + parser.add_argument( + "--model", + default="mlx-community/DeepSeek-R1-3bit", + help="HF repo or path to local model.", + ) + parser.add_argument( + "--prompt", + "-p", + default="Write a quicksort in C++.", + help="Message to be processed by the model ('-' reads from stdin)", + ) + parser.add_argument( + "--max-tokens", + "-m", + type=int, + default=256, + help="Maximum number of tokens to generate", + ) + args = parser.parse_args() + + group = mx.distributed.init(backend="mpi") + rank = group.rank() + + def rprint(*args, **kwargs): + if rank == 0: + print(*args, **kwargs) + + model, tokenizer = shard_and_load(args.model) + + messages = [{"role": "user", "content": args.prompt}] + prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) + + for response in stream_generate( + model, tokenizer, prompt, max_tokens=args.max_tokens + ): + rprint(response.text, end="", flush=True) + + rprint() + rprint("=" * 10) + rprint( + f"Prompt: {response.prompt_tokens} tokens, " + f"{response.prompt_tps:.3f} tokens-per-sec" + ) + rprint( + f"Generation: {response.generation_tokens} tokens, " + f"{response.generation_tps:.3f} tokens-per-sec" + ) + rprint(f"Peak memory: {response.peak_memory:.3f} GB") diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 9027da7e..3136ca7b 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -364,8 +364,29 @@ class DeepseekV2Model(nn.Module): DeepseekV2DecoderLayer(config, idx) for idx in range(config.num_hidden_layers) ] + self.start_idx = 0 + self.end_idx = len(self.layers) + 
self.num_layers = self.end_idx + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.pipeline_rank = 0 + self.pipeline_size = 1 + + def pipeline(self, group): + # Split layers in reverse so rank=0 gets the last layers and + # rank=pipeline_size-1 gets the first + self.pipeline_rank = group.rank() + self.pipeline_size = group.size() + layers_per_rank = ( + len(self.layers) + self.pipeline_size - 1 + ) // self.pipeline_size + self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.end_idx = self.start_idx + layers_per_rank + self.num_layers = layers_per_rank + self.layers = self.layers[: self.end_idx] + self.layers[: self.start_idx] = [None] * self.start_idx + def __call__( self, x: mx.array, @@ -374,14 +395,31 @@ class DeepseekV2Model(nn.Module): ) -> mx.array: h = self.embed_tokens(x) + pipeline_rank = self.pipeline_rank + pipeline_size = self.pipeline_size + # Hack to avoid time-outs during prompt-processing + dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu if mask is None: mask = create_attention_mask(h, cache) if cache is None: - cache = [None] * len(self.layers) + cache = [None] * self.num_layers - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) + # Receive from the previous process in the pipeline + if pipeline_rank < pipeline_size - 1: + h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) + + for i in range(self.num_layers): + h = self.layers[self.start_idx + i](h, mask, cache[i]) + + # Send to the next process in the pipeline + if pipeline_rank != 0: + h = mx.distributed.send( + h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream + ) + + # Broadcast h while keeping it in the graph + h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]] return self.norm(h) @@ -418,4 +456,4 @@ class Model(nn.Module): @property def layers(self): - return self.model.layers + return self.model.layers[self.model.start_idx : self.model.end_idx] diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index 96ce4f85..e6a0dd1e 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module): DeepseekV3DecoderLayer(config, idx) for idx in range(config.num_hidden_layers) ] + self.start_idx = 0 + self.end_idx = len(self.layers) + self.num_layers = self.end_idx + self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.pipeline_rank = 0 self.pipeline_size = 1 @@ -394,7 +398,11 @@ class DeepseekV3Model(nn.Module): len(self.layers) + self.pipeline_size - 1 ) // self.pipeline_size start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank - self.layers = self.layers[start : start + layers_per_rank] + self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank + self.end_idx = self.start_idx + layers_per_rank + self.num_layers = layers_per_rank + self.layers = self.layers[: self.end_idx] + self.layers[: self.start_idx] = [None] * self.start_idx def __call__( self, @@ -412,15 +420,15 @@ class DeepseekV3Model(nn.Module): mask = create_attention_mask(h, cache) if cache is None: - cache = [None] * len(self.layers) + cache = [None] * self.num_layers # Receive from the previous process in the pipeline if pipeline_rank < pipeline_size - 1: h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream) - for layer, c in zip(self.layers, cache): - h = layer(h, mask, c) + for i in range(self.num_layers): + h = self.layers[self.start_idx + i](h, mask, 
cache[i]) # Send to the next process in the pipeline if pipeline_rank != 0: @@ -468,4 +476,4 @@ class Model(nn.Module): @property def layers(self): - return self.model.layers + return self.model.layers[self.model.start_idx : self.model.end_idx] diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py index 0150f1b7..b2e89a13 100644 --- a/llms/mlx_lm/utils.py +++ b/llms/mlx_lm/utils.py @@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict: def load_model( model_path: Path, lazy: bool = False, + strict: bool = True, model_config: dict = {}, get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes, ) -> nn.Module: @@ -638,6 +639,8 @@ def load_model( lazy (bool): If False eval the model parameters to make sure they are loaded in memory before returning, otherwise they will be loaded when needed. Default: ``False`` + strict (bool): Whether or not to raise an exception if weights don't + match. Default: ``True`` model_config (dict, optional): Optional configuration parameters for the model. Defaults to an empty dictionary. get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional): @@ -660,7 +663,7 @@ def load_model( # Try weight for back-compat weight_files = glob.glob(str(model_path / "weight*.safetensors")) - if not weight_files: + if not weight_files and strict: logging.error(f"No safetensors found in {model_path}") raise FileNotFoundError(f"No safetensors found in {model_path}") @@ -694,7 +697,7 @@ def load_model( class_predicate=class_predicate, ) - model.load_weights(list(weights.items())) + model.load_weights(list(weights.items()), strict=strict) if not lazy: mx.eval(model.parameters())
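
For reference, a minimal standalone sketch of the two pieces of arithmetic behind the local-shard download: the reverse layer split used by DeepseekV2Model.pipeline / DeepseekV3Model.pipeline, and the lookup of each surviving parameter name in the "weight_map" of model.safetensors.index.json. The helpers (layer_range, local_shard_files), the 8-layer model, and the shard file names below are illustrative assumptions, not mlx_lm APIs; the real example derives the parameter names from tree_flatten(model.parameters()) after pruning, which also picks up the embedding, final norm, and head weights. Splitting in reverse presumably leaves the final layers (and hence the logits) on rank 0, which is also the rank that prints output in pipeline_generate.py.

```
# Sketch of the per-rank layer split and shard-file selection (pure Python,
# no MLX needed). The weight index below is a toy stand-in for the real
# model.safetensors.index.json.

def layer_range(num_layers, rank, size):
    # Same arithmetic as DeepseekV2Model.pipeline / DeepseekV3Model.pipeline:
    # split in reverse so rank 0 owns the last layers and rank size-1 the first.
    layers_per_rank = (num_layers + size - 1) // size
    start = (size - rank - 1) * layers_per_rank
    return range(start, min(start + layers_per_rank, num_layers))

def local_shard_files(weight_map, layers):
    # Keep only the safetensors files holding parameters for our layers.
    # (The example instead walks tree_flatten(model.parameters()) after
    # pipeline() has dropped the other ranks' layers, which also covers the
    # embedding, final norm, and lm_head; prefix matching is enough here.)
    prefixes = tuple(f"model.layers.{i}." for i in layers)
    return {f for k, f in weight_map.items() if k.startswith(prefixes)}

if __name__ == "__main__":
    # Toy index: 8 layers spread across 4 shard files.
    toy_weight_map = {
        f"model.layers.{i}.mlp.weight": f"model-0000{i // 2 + 1}-of-00004.safetensors"
        for i in range(8)
    }
    for rank in range(2):
        layers = layer_range(8, rank, size=2)
        print(rank, list(layers), sorted(local_shard_files(toy_weight_map, layers)))
```

With two ranks, rank 0 resolves layers 4-7 and rank 1 layers 0-3, each yielding only the shard files that hold those layers; passing such a file set back to snapshot_download(repo, allow_patterns=local_files), as shard_and_load does, is what keeps each host's download to roughly 1/pipeline_size of the weights.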