mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00

fix deepseek sharding (#1242)

This commit is contained in:
parent 0989c073b0
commit 21d0ab6e8a
@@ -44,7 +44,8 @@ def shard_and_load(repo):
         allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
     )
 
-    # Lazy load and shard model
+    # Lazy load and shard model to figure out
+    # which weights we need
     model, _ = load_model(model_path, lazy=True, strict=False)
 
     group = mx.distributed.init(backend="mpi")
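The rewritten comment leans on MLX's lazy evaluation: with lazy=True and strict=False, load_model can build the model structure before any weight shards have been downloaded (the allow_patterns above exclude safetensors), which is enough for the script to work out which weight files this rank needs. A minimal standalone sketch of the mechanism, not part of this commit, with arbitrary example shapes:

import mlx.core as mx

# MLX arrays are lazy: building the graph runs no kernels and materializes
# no data until mx.eval forces the result.
a = mx.ones((1024, 1024))
b = a @ a          # still unevaluated
mx.eval(b)         # the matmul actually runs here
print(b.shape)     # (1024, 1024)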
@@ -62,8 +63,11 @@ def shard_and_load(repo):
     # Download weights for local shard
     download(args.model, allow_patterns=local_files)
 
+    # Load and shard the model, and load the weights
     tokenizer = load_tokenizer(model_path)
-    model, _ = load_model(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
 
     # Synchronize processes before generation to avoid timeout if downloading
     # model for the first time.
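The reordered flow loads the model lazily, prunes it down to this rank's pipeline stage with model.model.pipeline(group), and only then materializes the remaining parameters with mx.eval, so weights belonging to other ranks (which were never downloaded) are never touched. A rough standalone sketch of that ordering; ToyModel and its pipeline method are illustrative stand-ins, not the mlx-lm API:

import mlx.core as mx
import mlx.nn as nn

class ToyModel(nn.Module):
    """Stand-in for a pipelined transformer: a flat stack of layers."""

    def __init__(self, num_layers=8, dim=16):
        super().__init__()
        self.layers = [nn.Linear(dim, dim) for _ in range(num_layers)]

    def pipeline(self, rank, size):
        # Keep only this rank's contiguous slice of layers; drop the rest so
        # their (still lazy) parameters are never evaluated.
        per_rank = (len(self.layers) + size - 1) // size
        start = (size - rank - 1) * per_rank
        self.layers = self.layers[: start + per_rank]
        self.layers[:start] = [None] * start

model = ToyModel()
model.pipeline(rank=0, size=4)   # shard first ...
mx.eval(model.parameters())      # ... then materialize only the kept layers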
@@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
         self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
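The added line matters when the layer count does not divide evenly by the pipeline size: after slicing, the rank holding the top of the stack keeps fewer than layers_per_rank layers, so num_layers has to be recomputed from the actual slice. A quick standalone check of the arithmetic; the 60-layer / 8-rank split is an example, not a value taken from this diff:

# Reproduce the slicing arithmetic from this hunk and compare how many layers
# each rank actually keeps against the naive layers_per_rank.
def kept_layers(num_layers, pipeline_size, pipeline_rank):
    layers_per_rank = (num_layers + pipeline_size - 1) // pipeline_size
    start_idx = (pipeline_size - pipeline_rank - 1) * layers_per_rank
    end_idx = min(start_idx + layers_per_rank, num_layers)  # the slice clips at the end
    return layers_per_rank, end_idx - start_idx

for rank in range(8):
    naive, actual = kept_layers(60, 8, rank)
    print(f"rank {rank}: layers_per_rank={naive}, actual={actual}")
# rank 0 keeps only 4 layers (60 - 56); the other ranks keep 8 each.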
@@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
-        self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
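Besides the same num_layers recount, this hunk drops the redundant start assignment that merely duplicated self.start_idx. A small standalone sketch of which layer indices each rank retains under this slicing (rank 0 ends up holding the final layers of the stack); the 61-layer / 4-rank split is an illustrative example, not a value from this diff:

def owned_indices(num_layers, pipeline_size, pipeline_rank):
    # Mirror the hunk: slice off layers past end_idx, blank out layers below
    # start_idx, and keep the rest for this rank.
    layers = list(range(num_layers))
    layers_per_rank = (num_layers + pipeline_size - 1) // pipeline_size
    start_idx = (pipeline_size - pipeline_rank - 1) * layers_per_rank
    end_idx = start_idx + layers_per_rank
    layers = layers[:end_idx]
    layers[:start_idx] = [None] * start_idx
    return [i for i in layers if i is not None]

for rank in range(4):
    idx = owned_indices(61, 4, rank)
    print(f"rank {rank}: layers {idx[0]}..{idx[-1]} ({len(idx)} layers)")
# rank 0 -> layers 48..60 (13 layers); ranks 1..3 -> 16 layers each.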