diff --git a/README.md b/README.md index 88888ad0..e47bd598 100644 --- a/README.md +++ b/README.md @@ -45,7 +45,7 @@ Some more useful examples are listed below. ### Hugging Face -Note: You can now directly download a few converted checkpoints from the [MLX +You can directly use or download converted checkpoints from the [MLX Community](https://huggingface.co/mlx-community) organization on Hugging Face. We encourage you to join the community and [contribute new models](https://github.com/ml-explore/mlx-examples/issues/155). diff --git a/llms/README.md b/llms/README.md index e943ed69..4f7451c1 100644 --- a/llms/README.md +++ b/llms/README.md @@ -164,7 +164,7 @@ mlx_lm.convert \ ``` Models can also be converted and quantized directly in the -[mlx-my-repo]https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging +[mlx-my-repo](https://huggingface.co/spaces/mlx-community/mlx-my-repo) Hugging Face Space. ### Long Prompts and Generations diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py index d170405a..1e4fb445 100644 --- a/llms/mlx_lm/examples/pipeline_generate.py +++ b/llms/mlx_lm/examples/pipeline_generate.py @@ -44,7 +44,8 @@ def shard_and_load(repo): allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"], ) - # Lazy load and shard model + # Lazy load and shard model to figure out + # which weights we need model, _ = load_model(model_path, lazy=True, strict=False) group = mx.distributed.init(backend="mpi") @@ -62,8 +63,11 @@ def shard_and_load(repo): # Download weights for local shard download(args.model, allow_patterns=local_files) + # Load and shard the model, and load the weights tokenizer = load_tokenizer(model_path) - model, _ = load_model(model_path) + model, _ = load_model(model_path, lazy=True, strict=False) + model.model.pipeline(group) + mx.eval(model.parameters()) # Synchronize processes before generation to avoid timeout if downloading # model for the first time. diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py index 3136ca7b..3581fcbe 100644 --- a/llms/mlx_lm/models/deepseek_v2.py +++ b/llms/mlx_lm/models/deepseek_v2.py @@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module): self.num_layers = layers_per_rank self.layers = self.layers[: self.end_idx] self.layers[: self.start_idx] = [None] * self.start_idx + self.num_layers = len(self.layers) - self.start_idx def __call__( self, diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py index e6a0dd1e..69ee1be0 100644 --- a/llms/mlx_lm/models/deepseek_v3.py +++ b/llms/mlx_lm/models/deepseek_v3.py @@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module): layers_per_rank = ( len(self.layers) + self.pipeline_size - 1 ) // self.pipeline_size - start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank self.end_idx = self.start_idx + layers_per_rank - self.num_layers = layers_per_rank self.layers = self.layers[: self.end_idx] self.layers[: self.start_idx] = [None] * self.start_idx + self.num_layers = len(self.layers) - self.start_idx def __call__( self, diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py index 6ca46a72..ff551bca 100644 --- a/llms/mlx_lm/models/helium.py +++ b/llms/mlx_lm/models/helium.py @@ -1,3 +1,5 @@ +# Copyright © 2025 Apple Inc. + from dataclasses import dataclass from typing import Any, Optional, Tuple diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py index f2414660..93cc616e 100644 --- a/llms/mlx_lm/models/mamba.py +++ b/llms/mlx_lm/models/mamba.py @@ -1,4 +1,4 @@ -# Copyright © 2024 Apple Inc. +# Copyright © 2024-2025 Apple Inc. import math from dataclasses import dataclass @@ -123,17 +123,16 @@ class MambaBlock(nn.Module): self.intermediate_size, self.hidden_size, bias=args.use_bias ) - def ssm_step(self, x, state=None): - A = -mx.exp(self.A_log) + def ssm_step(self, x, A, state=None): D = self.D deltaBC = self.x_proj(x) - delta, B, C = mx.split( - deltaBC, - indices_or_sections=[ - self.time_step_rank, - self.time_step_rank + self.ssm_state_size, - ], - axis=-1, + delta, B, C = map( + self.mixer_norm if self.use_bcdt_rms else lambda x: x, + mx.split( + deltaBC, + [self.time_step_rank, self.time_step_rank + self.ssm_state_size], + axis=-1, + ), ) if self.use_bcdt_rms: delta, B, C = map(self.mixer_norm, (delta, B, C)) @@ -145,25 +144,40 @@ class MambaBlock(nn.Module): y = y + D * x return y, new_state - def __call__(self, x, cache): + def _process_sequence(self, x, conv_cache, state_cache): B, T, D = x.shape - if cache is None: - cache = [None, None] + xz = self.in_proj(x) + x, z = xz.split(indices_or_sections=2, axis=-1) + + conv_out, new_conv_cache = self.conv1d(x, conv_cache) + x = nn.silu(conv_out) + + A = -mx.exp(self.A_log) outputs = [] + current_state = state_cache + y = [] for t in range(T): - xt = x[:, t, :] - xz = self.in_proj(xt) - x_t, z_t = xz.split(indices_or_sections=2, axis=1) - conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0]) - x_t = conv_out.squeeze(1) - x_t = nn.silu(x_t) - y_t, cache[1] = self.ssm_step(x_t, cache[1]) - z_t = nn.silu(z_t) - output_t = y_t * z_t - output_t = self.out_proj(output_t) - outputs.append(output_t) - output = mx.stack(outputs, axis=1) + y_t, current_state = self.ssm_step(x[:, t], A, current_state) + y.append(y_t) + y = mx.stack(y, axis=1) + z = self.out_proj(nn.silu(z) * y) + return z, (new_conv_cache, current_state) + + def __call__(self, x, cache): + if cache is None: + conv_cache, state_cache = None, None + else: + conv_cache, state_cache = cache[0], cache[1] + + output, (new_conv_cache, new_state_cache) = self._process_sequence( + x, conv_cache, state_cache + ) + + if isinstance(cache, MambaCache): + cache[0] = new_conv_cache + cache[1] = new_state_cache + return output diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py index edddd583..7140c577 100644 --- a/llms/mlx_lm/models/minicpm.py +++ b/llms/mlx_lm/models/minicpm.py @@ -1,4 +1,4 @@ -# Copyright © 2023-2024 Apple Inc. +# Copyright © 2023-2025 Apple Inc. from dataclasses import dataclass from typing import Any, Dict, Optional, Tuple, Union diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 63ca58bb..bf84d066 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -140,8 +140,8 @@ def evaluate( loss: callable = default_loss, iterate_batches: callable = iterate_batches, ): - all_losses = 0 - ntokens = 0 + all_losses = mx.array(0.0) + ntokens = mx.array(0) index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)