diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 747db9e2..888fb4fa 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -44,10 +44,6 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) -def silu(x): - return x * mx.sigmoid(x) - - class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() @@ -166,18 +162,18 @@ class Mamba2Block(nn.Module): if cache is None: cache = [None, None] - conv_state, _ = cache + conv_state, _ = cache zxBCdt = self.in_proj(u) z, xBC, dt = mx.split( - zxBCdt, + zxBCdt, [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state], axis=-1 ) xBC, conv_state = self.conv1d(xBC, conv_state) - xBC = silu(xBC) + xBC =xBC * mx.sigmoid(xBC) xBC = xBC[:, :seq_len, :] x, B, C = mx.split(