Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
fix sharding for more even number of layers (#1276)
commit f8cbf159e0
parent e879ea70e1
@@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
+
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.num_layers = layers_per_rank
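For reference, a minimal standalone sketch (not part of the diff; split_layers is a hypothetical helper) of the per-rank counts the new floor-division split produces. The old code assigned every rank the ceiling of len(self.layers) / pipeline_size, so the rank holding the tail of the stack could end up with noticeably fewer layers (e.g. roughly 16/16/16/13 for 61 layers over 4 ranks); the new split differs by at most one layer between ranks.

# Hypothetical standalone sketch illustrating the split used in this commit:
# each rank gets num_layers // pipeline_size layers, and the first `extra`
# ranks get one more, so counts differ by at most one and sum to the total.
def split_layers(num_layers: int, pipeline_size: int) -> list[int]:
    base = num_layers // pipeline_size
    extra = num_layers - base * pipeline_size
    return [base + 1 if rank < extra else base for rank in range(pipeline_size)]

assert split_layers(61, 4) == [16, 15, 15, 15]
assert split_layers(62, 8) == [8, 8, 8, 8, 8, 8, 7, 7]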
@@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.layers = self.layers[: self.end_idx]
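The DeepseekV3Model hunk applies the same split; the unchanged context lines then turn the per-rank count into a contiguous, reversed slice of the layer stack (rank pipeline_size-1 takes the first layers, rank 0 the last). A small sketch of that index arithmetic for an evenly divisible case, with hypothetical numbers not taken from the diff:

# Hypothetical numbers, not from the diff: 60 layers over 4 pipeline ranks.
num_layers, pipeline_size = 60, 4
layers_per_rank = num_layers // pipeline_size  # 15, no remainder in this case

for rank in range(pipeline_size):
    # Mirrors the start_idx / end_idx lines kept as context in the hunk above.
    start_idx = (pipeline_size - rank - 1) * layers_per_rank
    end_idx = start_idx + layers_per_rank
    print(rank, start_idx, end_idx)
# rank 3 -> layers [0, 15), rank 2 -> [15, 30), rank 1 -> [30, 45), rank 0 -> [45, 60)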