diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py
index d170405a..1e4fb445 100644
--- a/llms/mlx_lm/examples/pipeline_generate.py
+++ b/llms/mlx_lm/examples/pipeline_generate.py
@@ -44,7 +44,8 @@ def shard_and_load(repo):
         allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
     )
 
-    # Lazy load and shard model
+    # Lazy load and shard model to figure out
+    # which weights we need
     model, _ = load_model(model_path, lazy=True, strict=False)
 
     group = mx.distributed.init(backend="mpi")
@@ -62,8 +63,11 @@ def shard_and_load(repo):
     # Download weights for local shard
     download(args.model, allow_patterns=local_files)
 
+    # Load and shard the model, and load the weights
     tokenizer = load_tokenizer(model_path)
-    model, _ = load_model(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
 
     # Synchronize processes before generation to avoid timeout if downloading
     # model for the first time.
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 3136ca7b..3581fcbe 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
         self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index e6a0dd1e..69ee1be0 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
-        self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,