Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-10-31 19:18:09 +08:00)
	only download local shard (#1240)
@@ -4,10 +4,11 @@
Run with:

```
/path/to/mpirun \
 -np 2 \
mlx.launch \
 --hostfile /path/to/hosts.txt \
 python /path/to/pipeline_generate.py --prompt "hello world"
 --backend mpi \
 /path/to/pipeline_generate.py \
 --prompt "hello world"
```

Make sure you can run MLX over MPI on two hosts. For more information see the
@@ -17,62 +18,106 @@ https://ml-explore.github.io/mlx/build/html/usage/distributed.html).
"""

import argparse
import json
from pathlib import Path

import mlx.core as mx
from huggingface_hub import snapshot_download
from mlx.utils import tree_flatten
from mlx_lm import load, stream_generate

parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
    "--model",
    default="mlx-community/DeepSeek-R1-3bit",
    help="HF repo or path to local model.",
)
parser.add_argument(
    "--prompt",
    "-p",
    default="Write a quicksort in C++.",
    help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
    "--max-tokens",
    "-m",
    type=int,
    default=256,
    help="Maximum number of tokens to generate",
)
args = parser.parse_args()

model, tokenizer = load(args.model, lazy=True)

messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())

# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
from mlx_lm.utils import load_model, load_tokenizer


def rprint(*args, **kwargs):
    if rank == 0:
        print(*args, **kwargs)
def download(repo: str, allow_patterns: list[str]) -> Path:
    return Path(
        snapshot_download(
            repo,
            allow_patterns=allow_patterns,
        )
    )


for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
    rprint(response.text, end="", flush=True)
def shard_and_load(repo):
    # Get model path with everything but weight safetensors
    model_path = download(
        args.model,
        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
    )

rprint()
rprint("=" * 10)
rprint(
    f"Prompt: {response.prompt_tokens} tokens, "
    f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
    f"Generation: {response.generation_tokens} tokens, "
    f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")
    # Lazy load and shard model
    model, _ = load_model(model_path, lazy=True, strict=False)

    group = mx.distributed.init(backend="mpi")
    rank = group.rank()
    model.model.pipeline(group)

    # Figure out which files we need for the local shard
    with open(model_path / "model.safetensors.index.json", "r") as fid:
        weight_index = json.load(fid)["weight_map"]

    local_files = set()
    for k, _ in tree_flatten(model.parameters()):
        local_files.add(weight_index[k])

    # Download weights for local shard
    download(args.model, allow_patterns=local_files)

    tokenizer = load_tokenizer(model_path)
    model, _ = load_model(model_path)

    # Synchronize processes before generation to avoid timeout if downloading
    # model for the first time.
    mx.eval(mx.distributed.all_sum(mx.array(1.0), stream=mx.cpu))
    return model, tokenizer


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="LLM pipelined inference example")
    parser.add_argument(
        "--model",
        default="mlx-community/DeepSeek-R1-3bit",
        help="HF repo or path to local model.",
    )
    parser.add_argument(
        "--prompt",
        "-p",
        default="Write a quicksort in C++.",
        help="Message to be processed by the model ('-' reads from stdin)",
    )
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=256,
        help="Maximum number of tokens to generate",
    )
    args = parser.parse_args()

    group = mx.distributed.init(backend="mpi")
    rank = group.rank()

    def rprint(*args, **kwargs):
        if rank == 0:
            print(*args, **kwargs)

    model, tokenizer = shard_and_load(args.model)

    messages = [{"role": "user", "content": args.prompt}]
    prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

    for response in stream_generate(
        model, tokenizer, prompt, max_tokens=args.max_tokens
    ):
        rprint(response.text, end="", flush=True)

    rprint()
    rprint("=" * 10)
    rprint(
        f"Prompt: {response.prompt_tokens} tokens, "
        f"{response.prompt_tps:.3f} tokens-per-sec"
    )
    rprint(
        f"Generation: {response.generation_tokens} tokens, "
        f"{response.generation_tps:.3f} tokens-per-sec"
    )
    rprint(f"Peak memory: {response.peak_memory:.3f} GB")

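Worth calling out in the new `shard_and_load`: each rank first fetches only the config/tokenizer files, builds a lazy model skeleton, applies the pipeline split, and only then consults the safetensors index to see which weight files it actually needs before downloading anything large. A minimal sketch of that lookup step, assuming the standard `model.safetensors.index.json` layout (the helper name and the example parameter name are hypothetical):

```
import json
from pathlib import Path


def local_shard_files(model_path: Path, local_param_names) -> set:
    # The standard Hugging Face index maps each parameter name to the
    # safetensors file that stores it.
    with open(model_path / "model.safetensors.index.json") as fid:
        weight_map = json.load(fid)["weight_map"]
    return {weight_map[name] for name in local_param_names}


# Illustrative usage with a hypothetical parameter name; in the script above
# the names come from tree_flatten(model.parameters()) after pipelining, and
# the resulting file names are passed to snapshot_download as allow_patterns.
# files = local_shard_files(model_path, ["model.layers.0.mlp.down_proj.weight"])
```
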
@@ -364,8 +364,29 @@ class DeepseekV2Model(nn.Module):
            DeepseekV2DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.start_idx = 0
        self.end_idx = len(self.layers)
        self.num_layers = self.end_idx

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.pipeline_rank = 0
        self.pipeline_size = 1

    def pipeline(self, group):
        # Split layers in reverse so rank=0 gets the last layers and
        # rank=pipeline_size-1 gets the first
        self.pipeline_rank = group.rank()
        self.pipeline_size = group.size()
        layers_per_rank = (
            len(self.layers) + self.pipeline_size - 1
        ) // self.pipeline_size
        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
        self.end_idx = self.start_idx + layers_per_rank
        self.num_layers = layers_per_rank
        self.layers = self.layers[: self.end_idx]
        self.layers[: self.start_idx] = [None] * self.start_idx

    def __call__(
        self,
        x: mx.array,
@@ -374,14 +395,31 @@ class DeepseekV2Model(nn.Module):
    ) -> mx.array:
        h = self.embed_tokens(x)

        pipeline_rank = self.pipeline_rank
        pipeline_size = self.pipeline_size
        # Hack to avoid time-outs during prompt-processing
        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)
            cache = [None] * self.num_layers

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)
        # Receive from the previous process in the pipeline
        if pipeline_rank < pipeline_size - 1:
            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)

        for i in range(self.num_layers):
            h = self.layers[self.start_idx + i](h, mask, cache[i])

        # Send to the next process in the pipeline
        if pipeline_rank != 0:
            h = mx.distributed.send(
                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
            )

        # Broadcast h while keeping it in the graph
        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]

        return self.norm(h)

@@ -418,4 +456,4 @@ class Model(nn.Module):

    @property
    def layers(self):
        return self.model.layers
        return self.model.layers[self.model.start_idx : self.model.end_idx]

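The `pipeline()` method above assigns a contiguous block of decoder layers to each rank in reverse order, and records `start_idx`/`end_idx`/`num_layers` so that the sliced `layers` property and the per-layer cache stay consistent; entries before `start_idx` are replaced with `None`, which is what keeps those weights out of the local parameter tree (and therefore out of the local download). A small worked sketch of the arithmetic, assuming 60 decoder layers and a 2-process pipeline (both counts are illustrative):

```
# Reproduce the split computed by pipeline() for an illustrative config:
# 60 decoder layers shared across a 2-process pipeline.
num_layers_total, pipeline_size = 60, 2
layers_per_rank = (num_layers_total + pipeline_size - 1) // pipeline_size

for rank in range(pipeline_size):
    start_idx = (pipeline_size - rank - 1) * layers_per_rank
    end_idx = start_idx + layers_per_rank
    print(f"rank {rank}: layers[{start_idx}:{end_idx}]")

# rank 0: layers[30:60]  (the last block of layers)
# rank 1: layers[0:30]   (the first block, fed by the token embeddings)
```
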
@@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
            DeepseekV3DecoderLayer(config, idx)
            for idx in range(config.num_hidden_layers)
        ]
        self.start_idx = 0
        self.end_idx = len(self.layers)
        self.num_layers = self.end_idx

        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pipeline_rank = 0
        self.pipeline_size = 1
@@ -394,7 +398,11 @@ class DeepseekV3Model(nn.Module):
            len(self.layers) + self.pipeline_size - 1
        ) // self.pipeline_size
        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
        self.layers = self.layers[start : start + layers_per_rank]
        self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
        self.end_idx = self.start_idx + layers_per_rank
        self.num_layers = layers_per_rank
        self.layers = self.layers[: self.end_idx]
        self.layers[: self.start_idx] = [None] * self.start_idx

    def __call__(
        self,
@@ -412,15 +420,15 @@ class DeepseekV3Model(nn.Module):
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)
            cache = [None] * self.num_layers

        # Receive from the previous process in the pipeline

        if pipeline_rank < pipeline_size - 1:
            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)
        for i in range(self.num_layers):
            h = self.layers[self.start_idx + i](h, mask, cache[i])

        # Send to the next process in the pipeline
        if pipeline_rank != 0:
@@ -468,4 +476,4 @@ class Model(nn.Module):

    @property
    def layers(self):
        return self.model.layers
        return self.model.layers[self.model.start_idx : self.model.end_idx]

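Both model files share the same per-forward communication pattern: every rank except the one holding the first block receives the hidden state from rank + 1, runs its local slice of layers, sends the result toward rank 0, and finally all-gathers so that every process ends up with the output while the send stays in the compute graph. A stripped-down sketch of that pattern against MLX's distributed ops (`pipeline_step` is a hypothetical helper; the real code additionally switches between CPU and GPU streams depending on prompt length to avoid timeouts):

```
import mlx.core as mx


def pipeline_step(h, local_layers, group, stream=mx.cpu):
    # Hypothetical helper mirroring the recv -> compute -> send -> all_gather
    # pattern in DeepseekV2Model/DeepseekV3Model.__call__ above.
    rank, size = group.rank(), group.size()

    # Layers are assigned in reverse order, so the hidden state arrives from
    # rank + 1, the process that holds the preceding block.
    if rank < size - 1:
        h = mx.distributed.recv_like(h, rank + 1, stream=stream)

    for layer in local_layers:
        h = layer(h)

    # Forward the result to the next block; rank 0 holds the last block and
    # keeps the output instead of sending.
    if rank != 0:
        h = mx.distributed.send(h, (rank - 1) % size, stream=stream)

    # all_gather keeps the send in the graph and hands the final hidden state
    # to every rank; slicing recovers the original batch dimension.
    return mx.distributed.all_gather(h, stream=stream)[: h.shape[0]]
```

With a single process the two branches are skipped and this reduces to a plain forward pass.
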
@@ -627,6 +627,7 @@ def load_config(model_path: Path) -> dict:
def load_model(
    model_path: Path,
    lazy: bool = False,
    strict: bool = True,
    model_config: dict = {},
    get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,
) -> nn.Module:
@@ -638,6 +639,8 @@ def load_model(
        lazy (bool): If False eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``
        strict (bool): Whether or not to raise an exception if weights don't
            match. Default: ``True``
        model_config (dict, optional): Optional configuration parameters for the
            model. Defaults to an empty dictionary.
        get_model_classes (Callable[[dict], Tuple[Type[nn.Module], Type]], optional):
@@ -660,7 +663,7 @@ def load_model(
        # Try weight for back-compat
        weight_files = glob.glob(str(model_path / "weight*.safetensors"))

    if not weight_files:
    if not weight_files and strict:
        logging.error(f"No safetensors found in {model_path}")
        raise FileNotFoundError(f"No safetensors found in {model_path}")

@@ -694,7 +697,7 @@ def load_model(
            class_predicate=class_predicate,
        )

    model.load_weights(list(weights.items()))
    model.load_weights(list(weights.items()), strict=strict)

    if not lazy:
        mx.eval(model.parameters())

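The new `strict` flag is what makes the two-phase load in `shard_and_load` possible: with `strict=False`, `load_model` no longer requires any safetensors files on disk and passes `strict=False` through to `load_weights`, so a lazy, weightless model skeleton can be built just to decide which shards to fetch. A minimal sketch of that first phase (the repo id is illustrative, and the `model, _ = ...` unpacking follows the call sites in `shard_and_load` above):

```
from pathlib import Path

from huggingface_hub import snapshot_download
from mlx_lm.utils import load_model

# Phase 1: config/tokenizer only -- no *.safetensors are present yet, so this
# only works because strict=False skips the weight-file check and lets
# load_weights tolerate the missing parameters.
model_path = Path(
    snapshot_download(
        "mlx-community/DeepSeek-R1-3bit",
        allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
    )
)
model, _ = load_model(model_path, lazy=True, strict=False)  # no weights yet

# ...pipeline the model, work out the local shard files, download them,
# then load for real as shard_and_load does above.
```
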
Author: Awni Hannun