diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index f22b2e3f..7a5bdeb1 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -378,9 +378,11 @@ class DeepseekV2Model(nn.Module): # rank=pipeline_size-1 gets the first self.pipeline_rank = group.rank() self.pipeline_size = group.size() - layers_per_rank = ( - len(self.layers) + self.pipeline_size - 1 - ) // self.pipeline_size + layers_per_rank = len(self.layers) // self.pipeline_size + extra = len(self.layers) - layers_per_rank * self.pipeline_size + if self.pipeline_rank < extra: + layers_per_rank += 1 + self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank self.end_idx = self.start_idx + layers_per_rank self.num_layers = layers_per_rank diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index 2df93d9f..47e17236 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -410,9 +410,10 @@ class DeepseekV3Model(nn.Module): # rank=pipeline_size-1 gets the first self.pipeline_rank = group.rank() self.pipeline_size = group.size() - layers_per_rank = ( - len(self.layers) + self.pipeline_size - 1 - ) // self.pipeline_size + layers_per_rank = len(self.layers) // self.pipeline_size + extra = len(self.layers) - layers_per_rank * self.pipeline_size + if self.pipeline_rank < extra: + layers_per_rank += 1 self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank self.end_idx = self.start_idx + layers_per_rank self.layers = self.layers[: self.end_idx]