From b10afe3662d6598ceb1fd3ab6784ddb9f87ef124 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 10 Dec 2024 18:15:12 +0100 Subject: [PATCH] nits --- llms/mlx_lm/models/mamba2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 4468cdf1..981daa74 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -57,11 +57,12 @@ class MambaRMSNormGated(nn.Module): variance = mx.mean(hidden_states ** 2, axis=-1, keepdims=True) hidden_states = hidden_states * mx.rsqrt(variance + self.variance_epsilon) return self.weight * hidden_states - + def silu(x): return x * mx.sigmoid(x) + def ssd(x, A, B, C, chunk_size): batch, seqlen, nheads, dim = x.shape B = mx.expand_dims(B, axis=2) @@ -87,7 +88,7 @@ def ssd(x, A, B, C, chunk_size): outputs.append(y) return mx.concatenate(outputs, axis=1), state - + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): @@ -175,7 +176,7 @@ class Mamba2Block(nn.Module): # Calculate split indices and slice tensors z = proj[..., :self.d_inner] - x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.d_state)] + x_conv = proj[..., self.d_inner:self.d_inner + (self.d_inner + 2 * self.n_groups * self.d_state)] dt = proj[..., -self.n_heads:] # Process time steps