add pipeline generation and example

Awni Hannun 2025-01-05 14:55:23 -08:00
parent 18f380b177
commit 22d4a20dc2
4 changed files with 104 additions and 3 deletions


@ -0,0 +1,75 @@
# Copyright © 2024 Apple Inc.
"""
Run with:
```
/path/to/mpirun \
-np 2 \
--hostfile /path/to/hosts.txt \
python /path/to/pipeline_generate.py --prompt "hello world"
```
Make sure you can run MLX over MPI on two hosts. For more information see the
documentation:
https://ml-explore.github.io/mlx/build/html/usage/distributed.html
"""
import argparse

import mlx.core as mx

from mlx_lm import load, stream_generate

parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
    "--prompt",
    "-p",
    default="Hello world",
    help="Message to be processed by the model ('-' reads from stdin)",
)
parser.add_argument(
    "--max-tokens",
    "-m",
    type=int,
    default=128,
    help="Maximum number of tokens to generate",
)
args = parser.parse_args()

model_repo = "mlx-community/DeepSeek-V3-3bit-bf16"

# Load lazily so the weights are not materialized before the pipeline split below.
model, tokenizer = load(model_repo, lazy=True)

messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

group = mx.distributed.init()
rank = group.rank()
model.model.pipeline(group)
mx.eval(model.parameters())

# Synchronize processes before generation to avoid timeout if downloading
# model for the first time.
mx.eval(mx.distributed.all_sum(mx.array(1.0)))


def rprint(*args, **kwargs):
    if rank == 0:
        print(*args, **kwargs)


for response in stream_generate(model, tokenizer, prompt, max_tokens=args.max_tokens):
    rprint(response.text, end="", flush=True)

rprint()
rprint("=" * 10)
rprint(
    f"Prompt: {response.prompt_tokens} tokens, "
    f"{response.prompt_tps:.3f} tokens-per-sec"
)
rprint(
    f"Generation: {response.generation_tokens} tokens, "
    f"{response.generation_tps:.3f} tokens-per-sec"
)
rprint(f"Peak memory: {response.peak_memory:.3f} GB")


@ -373,6 +373,19 @@ class DeepseekV3Model(nn.Module):
            for idx in range(config.num_hidden_layers)
        ]
        self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.pipeline_rank = 0
        self.pipeline_size = 1

    def pipeline(self, group):
        # Split layers in reverse so rank=0 gets the last layers and
        # rank=pipeline_size-1 gets the first
        self.pipeline_rank = group.rank()
        self.pipeline_size = group.size()
        # Ceiling division: every rank keeps at most layers_per_rank layers
        layers_per_rank = (
            len(self.layers) + self.pipeline_size - 1
        ) // self.pipeline_size
        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
        self.layers = self.layers[start : start + layers_per_rank]

    def __call__(
        self,
@ -380,17 +393,30 @@ class DeepseekV3Model(nn.Module):
        cache: Optional[Any] = None,
        mask: Optional[mx.array] = None,
    ) -> mx.array:
        h = self.embed_tokens(x)

        pipeline_rank = self.pipeline_rank
        pipeline_size = self.pipeline_size

        if mask is None:
            mask = create_attention_mask(h, cache)

        if cache is None:
            cache = [None] * len(self.layers)

        # Receive from the previous process in the pipeline
        if pipeline_rank < pipeline_size - 1:
            h = mx.distributed.recv_like(h, (pipeline_rank + 1))

        for layer, c in zip(self.layers, cache):
            h = layer(h, mask, c)

        # Send to the next process in the pipeline
        if pipeline_rank != 0:
            h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)

        # Broadcast h while keeping it in the graph
        h = mx.distributed.all_gather(h)[: h.shape[0]]

        return self.norm(h)
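
To make the reverse split in `pipeline` concrete, here is an illustrative calculation of which layers each rank keeps, assuming two pipeline ranks and DeepSeek-V3's 61 hidden layers (the concrete numbers are an assumption for illustration, not part of the diff):

```
# Illustrative: layer assignment after model.pipeline(group) with 2 ranks.
num_layers, pipeline_size = 61, 2
layers_per_rank = (num_layers + pipeline_size - 1) // pipeline_size  # ceil -> 31
for rank in range(pipeline_size):
    start = (pipeline_size - rank - 1) * layers_per_rank
    kept = range(start, min(start + layers_per_rank, num_layers))
    print(f"rank {rank}: layers {kept.start}..{kept.stop - 1} ({len(kept)} layers)")
# rank 0: layers 31..60 (30 layers) -- the last layers
# rank 1: layers 0..30  (31 layers) -- the first layers
```

Rank 1 therefore runs the first layers and sends its activations to rank 0, which runs the last layers, matching the send/recv logic in `__call__` above.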


@ -561,7 +561,7 @@ def load(
            Defaults to an empty dictionary.
        adapter_path (str, optional): Path to the LoRA adapters. If provided, applies LoRA layers
            to the model. Default: ``None``.
        lazy (bool): If False eval the model parameters to make sure they are
        lazy (bool): If ``False`` eval the model parameters to make sure they are
            loaded in memory before returning, otherwise they will be loaded
            when needed. Default: ``False``

    Returns:


@ -17,7 +17,7 @@ class TestLoadModelCustomGetClasses(unittest.TestCase):
self.config = args
self.custom_attribute = "This is a custom model"
def load_weights(self, weights):
def load_weights(self, weights, **kwargs):
self.qwenWeights = weights
class CustomQwenConfig: