From d9924d08d15fbc145466f06489d106b219f12323 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 3 Feb 2025 09:55:24 -0800
Subject: [PATCH 1/3] Fix no validation in lora (#1241)

---
 llms/mlx_lm/tuner/trainer.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 63ca58bb..bf84d066 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -140,8 +140,8 @@ def evaluate(
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
 ):
-    all_losses = 0
-    ntokens = 0
+    all_losses = mx.array(0.0)
+    ntokens = mx.array(0)
 
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
 
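
The hunk above changes the loss accumulators in evaluate() from Python ints to mx.array scalars. The rest of evaluate() is not shown in this patch, but a plausible reading of the fix is that downstream code treats the accumulators as arrays (for example, reducing them with mx.distributed.all_sum() and calling .item()), which breaks when no validation batches run and the values are still plain ints. A minimal sketch of the pattern, with made-up helper and argument names:

    import mlx.core as mx

    def mean_eval_loss(batch_losses, batch_ntoks):
        # Start the accumulators as mx.array scalars, as the patch does.
        all_losses = mx.array(0.0)
        ntokens = mx.array(0)
        for lval, ntok in zip(batch_losses, batch_ntoks):
            all_losses += lval * ntok
            ntokens += ntok
        # Even if the loop body never runs (no validation data), the
        # accumulators are still mx.arrays, so array-only operations such
        # as .item() remain valid instead of failing on Python ints.
        return (all_losses / ntokens).item()
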
From 0989c073b056253e5fd59334d00919ee7a9accf8 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?=
 <60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Mon, 3 Feb 2025 22:36:08 +0100
Subject: [PATCH 2/3] Optimizations for mamba1 (#1213)

* Added mx.einsum() operations. Before: 41.293 tokens-per-sec, after: 57.822
  tokens-per-sec.

* Fused operations in the delta, B, C = ... split. Before: 57.822
  tokens-per-sec, after: 83.890 tokens-per-sec.

* Pre-computed A_log. Before: 83.890 tokens-per-sec, after: 85.848
  tokens-per-sec.

* Updated MambaBlock: batched input processing, improved cache handling,
  pre-computed constants, cleaner state management, explicit return values.
  Before: 82.442 tokens-per-sec, after: 129.130 tokens-per-sec.

* Cleaned up and added the Apple copyright to the helium model file.

* Updated the copyright year.

* Nits + even faster.

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/models/helium.py  |  2 ++
 llms/mlx_lm/models/mamba.py   | 64 +++++++++++++++++++++--------------
 llms/mlx_lm/models/minicpm.py |  2 +-
 3 files changed, 42 insertions(+), 26 deletions(-)

diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py
index 6ca46a72..ff551bca 100644
--- a/llms/mlx_lm/models/helium.py
+++ b/llms/mlx_lm/models/helium.py
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index f2414660..93cc616e 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
         y = y + D * x
         return y, new_state
 
-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
         outputs = []
+        current_state = state_cache
+        y = []
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
         return output
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index edddd583..7140c577 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
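
The mamba.py rewrite above moves everything that does not depend on the recurrent state out of the per-token loop: the input projection and convolution now run once over the whole (B, T, D) sequence, and A = -mx.exp(self.A_log) is computed once instead of at every step. The sketch below illustrates that hoisting pattern in isolation; the function names, shapes, and the stand-in update rule are illustrative, not the model code itself.

    import mlx.core as mx
    import mlx.nn as nn

    def slow_scan(x, proj, A_log):
        # Per-token work inside the loop: one small matmul and one exp per step.
        ys = []
        for t in range(x.shape[1]):
            A = -mx.exp(A_log)            # recomputed T times
            ys.append(proj(x[:, t]) * A)  # stand-in for the SSM update
        return mx.stack(ys, axis=1)

    def fast_scan(x, proj, A_log):
        # Hoisted: A is computed once and the projection runs over the full
        # sequence in one batched matmul; only the stateful part stays serial.
        A = -mx.exp(A_log)
        xp = proj(x)
        ys = [xp[:, t] * A for t in range(x.shape[1])]
        return mx.stack(ys, axis=1)

    proj = nn.Linear(16, 16)
    x = mx.random.normal((2, 8, 16))
    A_log = mx.zeros((16,))
    assert mx.allclose(slow_scan(x, proj, A_log), fast_scan(x, proj, A_log), atol=1e-5)
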
From 21d0ab6e8abd3ecc549c7db526b2097fd9089352 Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Mon, 3 Feb 2025 16:59:50 -0800
Subject: [PATCH 3/3] fix deepseek sharding (#1242)

---
 llms/mlx_lm/examples/pipeline_generate.py | 8 ++++++--
 llms/mlx_lm/models/deepseek_v2.py         | 1 +
 llms/mlx_lm/models/deepseek_v3.py         | 3 +--
 3 files changed, 8 insertions(+), 4 deletions(-)

diff --git a/llms/mlx_lm/examples/pipeline_generate.py b/llms/mlx_lm/examples/pipeline_generate.py
index d170405a..1e4fb445 100644
--- a/llms/mlx_lm/examples/pipeline_generate.py
+++ b/llms/mlx_lm/examples/pipeline_generate.py
@@ -44,7 +44,8 @@ def shard_and_load(repo):
         allow_patterns=["*.json", "*.py", "tokenizer.model", "*.tiktoken", "*.txt"],
     )
 
-    # Lazy load and shard model
+    # Lazy load and shard model to figure out
+    # which weights we need
     model, _ = load_model(model_path, lazy=True, strict=False)
 
     group = mx.distributed.init(backend="mpi")
@@ -62,8 +63,11 @@ def shard_and_load(repo):
     # Download weights for local shard
     download(args.model, allow_patterns=local_files)
 
+    # Load and shard the model, and load the weights
     tokenizer = load_tokenizer(model_path)
-    model, _ = load_model(model_path)
+    model, _ = load_model(model_path, lazy=True, strict=False)
+    model.model.pipeline(group)
+    mx.eval(model.parameters())
 
     # Synchronize processes before generation to avoid timeout if downloading
     # model for the first time.
diff --git a/llms/mlx_lm/models/deepseek_v2.py b/llms/mlx_lm/models/deepseek_v2.py
index 3136ca7b..3581fcbe 100644
--- a/llms/mlx_lm/models/deepseek_v2.py
+++ b/llms/mlx_lm/models/deepseek_v2.py
@@ -386,6 +386,7 @@ class DeepseekV2Model(nn.Module):
         self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
diff --git a/llms/mlx_lm/models/deepseek_v3.py b/llms/mlx_lm/models/deepseek_v3.py
index e6a0dd1e..69ee1be0 100644
--- a/llms/mlx_lm/models/deepseek_v3.py
+++ b/llms/mlx_lm/models/deepseek_v3.py
@@ -397,12 +397,11 @@ class DeepseekV3Model(nn.Module):
         layers_per_rank = (
             len(self.layers) + self.pipeline_size - 1
         ) // self.pipeline_size
-        start = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.start_idx = (self.pipeline_size - self.pipeline_rank - 1) * layers_per_rank
         self.end_idx = self.start_idx + layers_per_rank
-        self.num_layers = layers_per_rank
         self.layers = self.layers[: self.end_idx]
         self.layers[: self.start_idx] = [None] * self.start_idx
+        self.num_layers = len(self.layers) - self.start_idx
 
     def __call__(
         self,
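
The deepseek_v2/v3 change above recomputes num_layers from the truncated layer list rather than from layers_per_rank, which over-counts when the layer count does not divide evenly across pipeline ranks. A toy walk-through of the same arithmetic with made-up numbers (10 layers over 4 ranks, looking at rank 0, which owns the tail of the model in this scheme):

    layers = list(range(10))
    pipeline_size, pipeline_rank = 4, 0
    layers_per_rank = (len(layers) + pipeline_size - 1) // pipeline_size  # 3
    start_idx = (pipeline_size - pipeline_rank - 1) * layers_per_rank     # 9
    end_idx = start_idx + layers_per_rank                                 # 12

    layers = layers[:end_idx]                # only 10 layers exist, so len() == 10
    layers[:start_idx] = [None] * start_idx  # blank out layers owned by other ranks

    # Before the patch: num_layers = layers_per_rank         -> 3 (too many)
    # After the patch:  num_layers = len(layers) - start_idx -> 1 (what this rank holds)
    num_layers = len(layers) - start_idx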