fix sharding for more even number of layers (#1276)

Awni Hannun 2025-02-11 16:26:59 -08:00 committed by GitHub
parent e879ea70e1
commit f8cbf159e0
2 changed files with 9 additions and 6 deletions

llms/mlx_lm/models/deepseek_v2.py

@@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
+
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.num_layers = layers_per_rank
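
The old computation used ceiling division, so every rank was assigned the same rounded-up layer count and the computed indices could overshoot the end of the layer list, leaving one rank short. The new computation floors the division and gives one extra layer to each of the lowest `extra` ranks, so the per-rank counts sum exactly to the number of layers. A minimal standalone sketch of the two splits (the 61-layer, 4-rank figures are assumed for illustration, not taken from the commit):

    # Sketch comparing the old and new layer splits; values are hypothetical.
    def split_ceil(num_layers: int, size: int) -> list[int]:
        # Old behavior: ceiling division, same count on every rank.
        per_rank = (num_layers + size - 1) // size
        return [per_rank] * size

    def split_even(num_layers: int, size: int) -> list[int]:
        # New behavior: floor division, remainder spread over the low ranks.
        per_rank = num_layers // size
        extra = num_layers - per_rank * size
        return [per_rank + 1 if rank < extra else per_rank for rank in range(size)]

    print(split_ceil(61, 4))  # [16, 16, 16, 16] -> sums to 64, overshoots 61
    print(split_even(61, 4))  # [16, 15, 15, 15] -> sums to exactly 61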

llms/mlx_lm/models/deepseek_v3.py

@@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module):
         # rank=pipeline_size-1 gets the first
         self.pipeline_rank = group.rank()
         self.pipeline_size = group.size()
-        layers_per_rank = (
-            len(self.layers) + self.pipeline_size - 1
-        ) // self.pipeline_size
+        layers_per_rank = len(self.layers) // self.pipeline_size
+        extra = len(self.layers) - layers_per_rank * self.pipeline_size
+        if self.pipeline_rank < extra:
+            layers_per_rank += 1
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
         self.layers = self.layers[: self.end_idx]
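
Both hunks keep the reversed mapping described in the context lines: rank 0 hosts the last layers and rank pipeline_size-1 hosts the first. A minimal sketch of that mapping, assuming an evenly divisible case (60 layers on 4 ranks, figures chosen for illustration):

    # Sketch of the reversed rank-to-layer mapping; values are hypothetical.
    num_layers, pipeline_size = 60, 4
    layers_per_rank = num_layers // pipeline_size  # 15, no remainder here

    for rank in range(pipeline_size):
        start_idx = (pipeline_size - rank - 1) * layers_per_rank
        end_idx = start_idx + layers_per_rank
        print(f"rank {rank}: layers [{start_idx}, {end_idx})")

    # rank 0: layers [45, 60)  <- rank 0 gets the last layers
    # rank 1: layers [30, 45)
    # rank 2: layers [15, 30)
    # rank 3: layers [0, 15)   <- rank=pipeline_size-1 gets the first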