only download local shard (#1240)

This commit is contained in:
Awni Hannun
2025-02-02 13:58:44 -08:00
committed by GitHub
parent e8afb59de4
commit 9c2ef38d4d
4 changed files with 159 additions and 65 deletions

View File

@@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
DeepseekV3DecoderLayer(config, idx)
for idx in range(config.num_hidden_layers)
]
self.start_idx = 0
self.end_idx = len(self.layers)
self.num_layers = self.end_idx
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.pipeline_rank = 0
self.pipeline_size = 1
@@ -394,7 +398,11 @@ class DeepseekV3Model(nn.Module):
len(self.layers) + self.pipeline_size - 1
) // self.pipeline_size
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.layers = self.layers[start : start + layers_per_rank]
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
self.end_idx = self.start_idx + layers_per_rank
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
def __call__(
self,
@@ -412,15 +420,15 @@ class DeepseekV3Model(nn.Module):
mask = create_attention_mask(h, cache)
if cache is None:
cache = [None] * len(self.layers)
cache = [None] * self.num_layers
# Receive from the previous process in the pipeline
if pipeline_rank < pipeline_size - 1:
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
for layer, c in zip(self.layers, cache):
h = layer(h, mask, c)
for i in range(self.num_layers):
h = self.layers[self.start_idx + i](h, mask, cache[i])
# Send to the next process in the pipeline
if pipeline_rank != 0:
@@ -468,4 +476,4 @@ class Model(nn.Module):
@property
def layers(self):
return self.model.layers
return self.model.layers[self.model.start_idx : self.model.end_idx]