fix deepseek sharding (#1242)

Awni Hannun 2025-02-03 16:59:50 -08:00 committed by GitHub
parent 0989c073b0
commit 21d0ab6e8a
3 changed files with 8 additions and 4 deletions


@@ -44,7 +44,8 @@ def shard_and_load(repo):
         allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
     )
 
-    # Lazy load and shard model
+    # Lazy load and shard model to figure out
+    # which weights we need
     model, _ = load_model(model_path, lazy=True, strict=False)
 
     group = mx.distributed.init(backend="mpi")
@@ -62,8 +63,11 @@ def shard_and_load(repo):
     # Download weights for local shard
     download(args.model, allow_patterns=local_files)
 
+    # Load and shard the model, and load the weights
     tokenizer = load_tokenizer(model_path)
-    model, _ = load_model(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
 
     # Synchronize processes before generation to avoid timeout if downloading
     # model for the first time.
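
The three added lines lean on MLX's lazy evaluation: load_model(..., lazy=True, strict=False) builds the module with lazily loaded parameters (strict=False allowing weight files that are not on disk yet), model.model.pipeline(group) drops the layers that belong to other ranks, and mx.eval(model.parameters()) then materializes only what is left on this rank. A tiny standalone sketch of that laziness, using toy arrays rather than the model:

    import mlx.core as mx

    # Operations are only recorded; nothing is computed yet.
    a = mx.random.normal((1024, 1024))
    b = a @ a + 1

    # Computation happens here, all at once, when the result is forced.
    mx.eval(b)
    print(b.shape)  # (1024, 1024)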


@@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
         self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,


@@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
-        self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
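
Both model diffs make the same fix: num_layers is now derived from the layers the rank actually keeps rather than assumed to be a full layers_per_rank chunk. A standalone sketch of the split arithmetic (the 61-layer count and 2-rank split are illustrative assumptions, not taken from the diff) shows how the old value over-counts on the rank holding the final, shorter chunk:

    n_layers = 61       # illustrative; any count not divisible by the rank count shows the issue
    pipeline_size = 2

    layers_per_rank = (n_layers + pipeline_size - 1) // pipeline_size  # 31

    for pipeline_rank in range(pipeline_size):
        layers = list(range(n_layers))  # stand-ins for the decoder layers

        # Reverse split, as in the models above: rank 0 keeps the last chunk.
        start_idx = (pipeline_size - pipeline_rank - 1) * layers_per_rank
        end_idx = start_idx + layers_per_rank
        layers = layers[:end_idx]
        layers[:start_idx] = [None] * start_idx

        old_num_layers = layers_per_rank          # value before this commit
        new_num_layers = len(layers) - start_idx  # value after this commit

        print(pipeline_rank, start_idx, end_idx, old_num_layers, new_num_layers)

    # rank 0: start_idx=31, end_idx=62 -> 30 real layers, old count said 31
    # rank 1: start_idx=0,  end_idx=31 -> 31 real layers, old and new agree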