some fixes for pipeline parallel deep seek r1 (#1216)

This commit is contained in:
Awni Hannun 2025-01-21 19:40:29 -08:00 committed by GitHub
parent df1406735b
commit 9a3ddc3e65
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 14 additions and 6 deletions

View File

@ -22,6 +22,11 @@ import mlx.core as mx
from mlx_lm import load, stream_generate from mlx_lm import load, stream_generate
parser = argparse.ArgumentParser(description="LLM pipelined inference example") parser = argparse.ArgumentParser(description="LLM pipelined inference example")
parser.add_argument(
"--model",
default="mlx-community/DeepSeek-R1-3bit",
help="HF repo or path to local model.",
)
parser.add_argument( parser.add_argument(
"--prompt", "--prompt",
"-p", "-p",
@ -37,9 +42,7 @@ parser.add_argument(
) )
args = parser.parse_args() args = parser.parse_args()
model_repo = "mlx-community/DeepSeek-V3-3bit" model, tokenizer = load(args.model, lazy=True)
model, tokenizer = load(model_repo, lazy=True)
messages = [{"role": "user", "content": args.prompt}] messages = [{"role": "user", "content": args.prompt}]
prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True) prompt = tokenizer.apply_chat_template(messages, add_generation_prompt=True)

View File

@ -400,6 +400,8 @@ class DeepseekV3Model(nn.Module):
pipeline_rank = self.pipeline_rank pipeline_rank = self.pipeline_rank
pipeline_size = self.pipeline_size pipeline_size = self.pipeline_size
# Hack to avoid time-outs during prompt-processing
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
if mask is None: if mask is None:
mask = create_attention_mask(h, cache) mask = create_attention_mask(h, cache)
@ -407,18 +409,21 @@ class DeepseekV3Model(nn.Module):
cache = [None] * len(self.layers) cache = [None] * len(self.layers)
# Receive from the previous process in the pipeline # Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1: if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1)) h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for layer, c in zip(self.layers, cache): for layer, c in zip(self.layers, cache):
h = layer(h, mask, c) h = layer(h, mask, c)
# Send to the next process in the pipeline # Send to the next process in the pipeline
if pipeline_rank != 0: if pipeline_rank != 0:
h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size) h = mx.distributed.send(
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
)
# Broadcast h while keeping it in the graph # Broadcast h while keeping it in the graph
h = mx.distributed.all_gather(h)[: h.shape[0]] h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
return self.norm(h) return self.norm(h)