From 22d4a20dc216d46a0ec2ada2d342e9f60e6b0344 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Sun, 5 Jan 2025 14:55:23 -0800
Subject: [PATCH] add pipeline generation and example

---
 llms/mlx_lm/examples/pipeline_generate.py | 75 +++++++++++++++++++++++
 llms/mlx_lm/models/deepseek_v3.py         | 28 ++++++++-
 llms/mlx_lm/utils.py                      |  2 +-
 llms/tests/test_utils_load_model.py       |  2 +-
 4 files changed, 104 insertions(+), 3 deletions(-)
 create mode 100644 llms/mlx_lm/examples/pipeline_generate.py

diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py
new file mode 100644
index 00000000..95f0fcb8
--- /dev/null
+++ b/llms/mlx_lm/examples/pipeline_generate.py
@@ -0,0 +1,75 @@
+# Copyright © 2024 Apple Inc.
+
+"""
+Run with:
+
+```
+/path/to/mpirun \
+    -np 2 \
+    --hostfile /path/to/hosts.txt \
+    python /path/to/pipeline_generate.py --prompt "hello world"
+```
+
+Make sure you can run MLX over MPI on two hosts. For more information see the
+documentation:
+
+https://ml-explore.github.io/mlx/build/html/usage/distributed.html
+"""
+
+import argparse
+
+import mlx.core as mx
+from mlx_lm import load, stream_generate
+
+parser = argparse.ArgumentParser(description="LLM pipelined inference example")
+parser.add_argument(
+    "--prompt",
+    "-p",
+    default="Hello world",
+    help="Message to be processed by the model ('-' reads from stdin)",
+)
+parser.add_argument(
+    "--max-tokens",
+    "-m",
+    type=int,
+    default=128,
+    help="Maximum number of tokens to generate",
+)
+args = parser.parse_args()
+
+model_repo = "mlx-community/DeepSeek-V3-3bit-bf16"
+
+model, tokenizer = load(model_repo, lazy=True)
+
+messages = [{"role": "user", "content": args.prompt}]
+prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
+
+group = mx.distributed.init()
+rank = group.rank()
+model.model.pipeline(group)
+mx.eval(model.parameters())
+
+# Synchronize processes before generation to avoid timeout if downloading
+# model for the first time.
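+# The collective all_sum below cannot complete until every rank reaches it,
+# so evaluating it acts as a barrier across the group.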
+mx.eval(mx.distributed.all_sum(mx.array(1.0)))
+
+
+def rprint(*args, **kwargs):
+    if rank == 0:
+        print(*args, **kwargs)
+
+
+for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
+    rprint(response.text, end="", flush=True)
+
+rprint()
+rprint("=" * 10)
+rprint(
+    f"Prompt: {response.prompt_tokens} tokens, "
+    f"{response.prompt_tps:.3f} tokens-per-sec"
+)
+rprint(
+    f"Generation: {response.generation_tokens} tokens, "
+    f"{response.generation_tps:.3f} tokens-per-sec"
+)
+rprint(f"Peak memory: {response.peak_memory:.3f} GB")
diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index f7a1ec75..ee27a60e 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -373,6 +373,19 @@ class DeepseekV3Model(nn.Module):
             for idx in range(config.num_hidden_layers)
         ]
         self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pipeline_rank = 0
+        self.pipeline_size = 1
+
+    def pipeline(self, group):
+        # Split layers in reverse so rank=0 gets the last layers and
+        # rank=pipeline_size-1 gets the first
+        self.pipeline_rank = group.rank()
+        self.pipeline_size = group.size()
+        layers_per_rank = (
+            len(self.layers) + self.pipeline_size - 1
+        ) // self.pipeline_size
+        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
+        self.layers = self.layers[start : start + layers_per_rank]
 
     def __call__(
         self,
@@ -380,17 +393,30 @@ class DeepseekV3Model(nn.Module):
         cache: Optional[Any] = None,
         mask: Optional[mx.array] = None,
     ) -> mx.array:
         h = self.embed_tokens(x)
+
+        pipeline_rank = self.pipeline_rank
+        pipeline_size = self.pipeline_size
         if mask is None:
             mask = create_attention_mask(h, cache)
 
         if cache is None:
             cache = [None] * len(self.layers)
 
+        # Receive from the previous process in the pipeline
+        if pipeline_rank < pipeline_size - 1:
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1))
+
         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)
 
+        # Send to the next process in the pipeline
+        if pipeline_rank != 0:
+            h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
+
+        # Broadcast h while keeping it in the graph
+        h = mx.distributed.all_gather(h)[: h.shape[0]]
+
         return self.norm(h)
diff --git a/llms/mlx_lm/utils.py b/llms/mlx_lm/utils.py
index 8e5d0be2..0e06b5a0 100644
--- a/llms/mlx_lm/utils.py
+++ b/llms/mlx_lm/utils.py
@@ -561,7 +561,7 @@ def load(
             Defaults to an empty dictionary.
         adapter_path (str, optional): Path to the LoRA adapters. If provided, applies
             LoRA layers to the model. Default: ``None``.
-        lazy (bool): If False eval the model parameters to make sure they are
+        lazy (bool): If ``False`` eval the model parameters to make sure they are
             loaded in memory before returning, otherwise they will be loaded
             when needed. Default: ``False``
     Returns:
diff --git a/llms/tests/test_utils_load_model.py b/llms/tests/test_utils_load_model.py
index 5821f9e9..8da19afb 100644
--- a/llms/tests/test_utils_load_model.py
+++ b/llms/tests/test_utils_load_model.py
@@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
                 self.config = args
                 self.custom_attribute = "This is a custom model"
 
-            def load_weights(self, weights):
+            def load_weights(self, weights, **kwargs):
                 self.qwenWeights = weights
 
         class CustomQwenConfig:
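
For reference, here is a small standalone sketch (not part of the patch) that mirrors the layer-split arithmetic in DeepseekV3Model.pipeline() above and shows which decoder layers each rank keeps. The split_layers helper and the 8-layer, 2-rank numbers are invented purely for illustration.

```
# Mirrors the ceil-division split used by DeepseekV3Model.pipeline(): layers
# are handed out in reverse rank order, so the highest rank gets the first
# block of layers and rank 0 gets the last block.
def split_layers(num_layers, pipeline_size):
    layers_per_rank = (num_layers + pipeline_size - 1) // pipeline_size
    assignment = {}
    for rank in range(pipeline_size):
        start = (pipeline_size - rank - 1) * layers_per_rank
        assignment[rank] = list(range(num_layers))[start : start + layers_per_rank]
    return assignment

print(split_layers(8, 2))  # {0: [4, 5, 6, 7], 1: [0, 1, 2, 3]}
```

With two ranks, every process computes the token embeddings, rank 1 runs the first half of the decoder layers and sends its activations to rank 0, rank 0 runs the second half, and the trailing all_gather broadcasts the final hidden state back to every rank so each process finishes the forward pass with the same result.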