mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 04:12:16 +08:00
fix sharding for more even number of layers
This commit is contained in:
parent
f58c7de901
commit
361e3547e8
@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module):
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
layers_per_rank = len(self.layers) // self.pipeline_size
|
||||
extra = len(self.layers) - layers_per_rank * self.pipeline_size
|
||||
if self.pipeline_rank < extra:
|
||||
layers_per_rank += 1
|
||||
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
|
@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module):
|
||||
# rank=pipeline_size-1 gets the first
|
||||
self.pipeline_rank = group.rank()
|
||||
self.pipeline_size = group.size()
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
layers_per_rank = len(self.layers) // self.pipeline_size
|
||||
extra = len(self.layers) - layers_per_rank * self.pipeline_size
|
||||
if self.pipeline_rank < extra:
|
||||
layers_per_rank += 1
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
|
Loading…
Reference in New Issue
Block a user