diff --git a/llms/mlx_lm/models/helium.py b/llms/mlx_lm/models/helium.py
index 6ca46a72..ff551bca 100644
--- a/llms/mlx_lm/models/helium.py
+++ b/llms/mlx_lm/models/helium.py
@@ -1,3 +1,5 @@
+# Copyright © 2025 Apple Inc.
+
 from dataclasses import dataclass
 from typing import Any, Optional, Tuple
 
diff --git a/llms/mlx_lm/models/mamba.py b/llms/mlx_lm/models/mamba.py
index f2414660..93cc616e 100644
--- a/llms/mlx_lm/models/mamba.py
+++ b/llms/mlx_lm/models/mamba.py
@@ -1,4 +1,4 @@
-# Copyright © 2024 Apple Inc.
+# Copyright © 2024-2025 Apple Inc.
 
 import math
 from dataclasses import dataclass
@@ -123,17 +123,16 @@ class MambaBlock(nn.Module):
             self.intermediate_size, self.hidden_size, bias=args.use_bias
         )
 
-    def ssm_step(self, x, state=None):
-        A = -mx.exp(self.A_log)
+    def ssm_step(self, x, A, state=None):
         D = self.D
         deltaBC = self.x_proj(x)
-        delta, B, C = mx.split(
-            deltaBC,
-            indices_or_sections=[
-                self.time_step_rank,
-                self.time_step_rank + self.ssm_state_size,
-            ],
-            axis=-1,
+        delta, B, C = map(
+            self.mixer_norm if self.use_bcdt_rms else lambda x: x,
+            mx.split(
+                deltaBC,
+                [self.time_step_rank, self.time_step_rank + self.ssm_state_size],
+                axis=-1,
+            ),
         )
         if self.use_bcdt_rms:
             delta, B, C = map(self.mixer_norm, (delta, B, C))
@@ -145,25 +144,40 @@ class MambaBlock(nn.Module):
         y = y + D * x
         return y, new_state
 
-    def __call__(self, x, cache):
+    def _process_sequence(self, x, conv_cache, state_cache):
         B, T, D = x.shape
-        if cache is None:
-            cache = [None, None]
+        xz = self.in_proj(x)
+        x, z = xz.split(indices_or_sections=2, axis=-1)
+
+        conv_out, new_conv_cache = self.conv1d(x, conv_cache)
+        x = nn.silu(conv_out)
+
+        A = -mx.exp(self.A_log)
         outputs = []
+        current_state = state_cache
+        y = []
         for t in range(T):
-            xt = x[:, t, :]
-            xz = self.in_proj(xt)
-            x_t, z_t = xz.split(indices_or_sections=2, axis=1)
-            conv_out, cache[0] = self.conv1d(mx.expand_dims(x_t, 1), cache[0])
-            x_t = conv_out.squeeze(1)
-            x_t = nn.silu(x_t)
-            y_t, cache[1] = self.ssm_step(x_t, cache[1])
-            z_t = nn.silu(z_t)
-            output_t = y_t * z_t
-            output_t = self.out_proj(output_t)
-            outputs.append(output_t)
-        output = mx.stack(outputs, axis=1)
+            y_t, current_state = self.ssm_step(x[:, t], A, current_state)
+            y.append(y_t)
+        y = mx.stack(y, axis=1)
+        z = self.out_proj(nn.silu(z) * y)
+        return z, (new_conv_cache, current_state)
+
+    def __call__(self, x, cache):
+        if cache is None:
+            conv_cache, state_cache = None, None
+        else:
+            conv_cache, state_cache = cache[0], cache[1]
+
+        output, (new_conv_cache, new_state_cache) = self._process_sequence(
+            x, conv_cache, state_cache
+        )
+
+        if isinstance(cache, MambaCache):
+            cache[0] = new_conv_cache
+            cache[1] = new_state_cache
+
         return output
 
 
 class ResidualBlock(nn.Module):
diff --git a/llms/mlx_lm/models/minicpm.py b/llms/mlx_lm/models/minicpm.py
index edddd583..7140c577 100644
--- a/llms/mlx_lm/models/minicpm.py
+++ b/llms/mlx_lm/models/minicpm.py
@@ -1,4 +1,4 @@
-# Copyright © 2023-2024 Apple Inc.
+# Copyright © 2023-2025 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Any, Dict, Optional, Tuple, Union
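
For reference, the part of the mamba.py refactor that must stay sequential is the selective-scan recurrence inside ssm_step; in_proj, the convolution, and the output gating are now applied to the whole sequence outside the per-token loop. Below is a minimal standalone sketch of that recurrence (the ssm_scan helper and its argument names and shapes are illustrative assumptions, not mlx-lm's API; the delta/B/C projections and softplus are omitted):

    import mlx.core as mx

    def ssm_scan(x, A, B_t, C_t, delta, D):
        # Assumed shapes: x, delta: (batch, T, d_inner); A: (d_inner, d_state);
        # B_t, C_t: (batch, T, d_state); D: (d_inner,).
        # Mirrors the update in ssm_step:
        #   state <- exp(delta * A) * state + (delta * x_t) * B_t
        #   y_t = state @ C_t + D * x_t
        batch, T, d_inner = x.shape
        state = mx.zeros((batch, d_inner, A.shape[-1]))
        ys = []
        for t in range(T):
            dt = delta[:, t]  # (batch, d_inner)
            state = state * mx.exp(mx.expand_dims(dt, -1) * A)
            state = state + mx.expand_dims(dt * x[:, t], -1) * mx.expand_dims(B_t[:, t], 1)
            y = (state @ mx.expand_dims(C_t[:, t], -1)).squeeze(-1)
            ys.append(y + D * x[:, t])
        return mx.stack(ys, axis=1)  # (batch, T, d_inner)

Because each step depends on the previous state, the loop itself cannot be batched over T; the gain from _process_sequence is that everything else (projections, conv1d, SiLU, gating) now runs once per sequence instead of once per token.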