some fixes for pipeline parallel deep seek r1 (#1216)

Awni Hannun
2025-01-21 19:40:29 -08:00
committed by GitHub
parent df1406735b
commit 9a3ddc3e65
2 changed files with 14 additions and 6 deletions


@@ -400,6 +400,8 @@ class DeepseekV3Model(nn.Module):
         pipeline_rank = self.pipeline_rank
         pipeline_size = self.pipeline_size
+        # Hack to avoid time-outs during prompt-processing
+        dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
         if mask is None:
             mask = create_attention_mask(h, cache)
@@ -407,18 +409,21 @@
             cache = [None] * len(self.layers)

         # Receive from the previous process in the pipeline
         if pipeline_rank < pipeline_size - 1:
-            h = mx.distributed.recv_like(h, (pipeline_rank + 1))
+            h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)

         for layer, c in zip(self.layers, cache):
             h = layer(h, mask, c)

         # Send to the next process in the pipeline
         if pipeline_rank != 0:
-            h = mx.distributed.send(h, (pipeline_rank - 1) % pipeline_size)
+            h = mx.distributed.send(
+                h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
+            )

         # Broadcast h while keeping it in the graph
-        h = mx.distributed.all_gather(h)[: h.shape[0]]
+        h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]

         return self.norm(h)
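
The change routes the pipeline's distributed recv_like/send/all_gather through the CPU stream during prompt processing (when the sequence dimension is greater than 1), so the collectives are not queued behind long GPU kernels, which is what triggered the time-outs; single-token decoding keeps them on the GPU stream. A minimal sketch of that stream-selection pattern, separate from the commit and using only the mx.distributed calls that appear in the diff (the broadcast_hidden helper and the toy shapes are illustrative):

import mlx.core as mx

# Returns a single-process group when no communication backend is available.
mx.distributed.init()

def broadcast_hidden(h: mx.array) -> mx.array:
    # Prompt processing has sequence length > 1; token-by-token decoding has 1.
    dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
    # all_gather concatenates along axis 0, so slicing back to the local batch
    # size broadcasts the value while keeping it in the compute graph.
    return mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]

h = mx.zeros((1, 8, 16))  # (batch, sequence, hidden) toy activation
print(broadcast_hidden(h).shape)  # (1, 8, 16)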