mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
fix deepseek sharding (#1242)
This commit is contained in:
parent
0989c073b0
commit
21d0ab6e8a
@ -44,7 +44,8 @@ def shard_and_load(repo):
|
||||
allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
|
||||
)
|
||||
|
||||
# Lazy load and shard model
|
||||
# Lazy load and shard model to figure out
|
||||
# which weights we need
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
|
||||
group = mx.distributed.init(backend="mpi")
|
||||
@ -62,8 +63,11 @@ def shard_and_load(repo):
|
||||
# Download weights for local shard
|
||||
download(args.model, allow_patterns=local_files)
|
||||
|
||||
# Load and shard the model, and load the weights
|
||||
tokenizer = load_tokenizer(model_path)
|
||||
model, _ = load_model(model_path)
|
||||
model, _ = load_model(model_path, lazy=True, strict=False)
|
||||
model.model.pipeline(group)
|
||||
mx.eval(model.parameters())
|
||||
|
||||
# Synchronize processes before generation to avoid timeout if downloading
|
||||
# model for the first time.
|
||||
|
@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
|
||||
self.num_layers = layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
|
||||
layers_per_rank = (
|
||||
len(self.layers) + self.pipeline_size - 1
|
||||
) // self.pipeline_size
|
||||
start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
|
||||
self.end_idx = self.start_idx + layers_per_rank
|
||||
self.num_layers = layers_per_rank
|
||||
self.layers = self.layers[: self.end_idx]
|
||||
self.layers[: self.start_idx] = [None] * self.start_idx
|
||||
self.num_layers = len(self.layers) - self.start_idx
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
|
Loading…
Reference in New Issue
Block a user