mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-11-07 15:28:09 +08:00
only download local shard (#1240)
This commit is contained in:
@@ -364,8 +364,29 @@ class DeepseekV2Model(nn.Module):
|
||||
DeepseekV2DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
|
||||
def pipeline(self, group):
|
||||
# Split layers in reverse so rank=0 gets the last layers and
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x: mx.array,
|
||||
@@ -374,14 +395,31 @@ class DeepseekV2Model(nn.Module):
|
||||
) -> mx.array:
|
||||
h = self.embed_tokens(x)
|
||||
|
||||
pipeline_rank = self.pipeline_rank
|
||||
pipeline_size = self.pipeline_size
|
||||
# Hack to avoid time-outs during prompt-processing
|
||||
dist_stream = mx.cpu if h.shape[1] > 1 else mx.gpu
|
||||
if mask is None:
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
# Receive from the previous process in the pipeline
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
h = mx.distributed.send(
|
||||
h, (pipeline_rank - 1) % pipeline_size, stream=dist_stream
|
||||
)
|
||||
|
||||
# Broadcast h while keeping it in the graph
|
||||
h = mx.distributed.all_gather(h, stream=dist_stream)[: h.shape[0]]
|
||||
|
||||
return self.norm(h)
|
||||
|
||||
@@ -418,4 +456,4 @@ class Model(nn.Module):
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
|
||||
@@ -381,6 +381,10 @@ class DeepseekV3Model(nn.Module):
|
||||
DeepseekV3DecoderLayer(config, idx)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
self.start_idx = 0
|
||||
self.end_idx = len(self.layers)
|
||||
self.num_layers = self.end_idx
|
||||
|
||||
self.norm = nn.RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
||||
self.pipeline_rank = 0
|
||||
self.pipeline_size = 1
|
||||
@@ -394,7 +398,11 @@ class DeepseekV3Model(nn.Module):
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.layers = self.layers[start : start + layers_per_rank]
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@@ -412,15 +420,15 @@ class DeepseekV3Model(nn.Module):
|
||||
mask = create_attention_mask(h, cache)
|
||||
|
||||
if cache is None:
|
||||
cache = [None] * len(self.layers)
|
||||
cache = [None] * self.num_layers
|
||||
|
||||
# Receive from the previous process in the pipeline
|
||||
|
||||
if pipeline_rank < pipeline_size - 1:
|
||||
h = mx.distributed.recv_like(h, (pipeline_rank + 1), stream=dist_stream)
|
||||
|
||||
for layer, c in zip(self.layers, cache):
|
||||
h = layer(h, mask, c)
|
||||
for i in range(self.num_layers):
|
||||
h = self.layers[self.start_idx + i](h, mask, cache[i])
|
||||
|
||||
# Send to the next process in the pipeline
|
||||
if pipeline_rank != 0:
|
||||
@@ -468,4 +476,4 @@ class Model(nn.Module):
|
||||
|
||||
@property
|
||||
def layers(self):
|
||||
return self.model.layers
|
||||
return self.model.layers[self.model.start_idx : self.model.end_idx]
|
||||
|
||||
Reference in New Issue
Block a user