fix deepseek sharding (#1242)

This commit is contained in:
Awni Hannun
2025-02-03 16:59:50 -08:00
committed by GitHub
parent 0989c073b0
commit 21d0ab6e8a
3 changed files with 8 additions and 4 deletions

View File

@@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
self.num_layers = layers_per_rank
self.layers = self.layers[: self.end_idx]
self.layers[: self.start_idx] = [None] * self.start_idx
self.num_layers = len(self.layers) - self.start_idx
def __call__(
self,