diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index eb4ae8b3..e5a9133f 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -161,7 +161,7 @@ class Mamba2Block(nn.Module):
         super().__init__()
         self.args = args
 
-        # Dimensions
+        # Same dimensions as before
         self.d_model = args.hidden_size
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
@@ -190,51 +190,46 @@ class Mamba2Block(nn.Module):
         )
 
         # Output projections
-        self.norm = MambaRMSNormGated(
-            self.d_inner,
-            eps=args.layer_norm_epsilon,
-            norm_before_gate=args.norm_before_gate
-        )
+        self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
 
     def __call__(self, u: mx.array, cache=None):
         batch_size, seq_len, _ = u.shape
 
-        if cache is None:
-            cache = [None, None]
+        # Get or initialize states from cache
+        if cache is None:
+            cache = [None, None]  # [conv_state, ssm_state]
+        conv_state, _ = cache  # We ignore ssm_state as it's not used in the parallel version
+
+        # Project input
         zxBCdt = self.in_proj(u)
-        
+
         # Split projections
         z, xBC, dt = mx.split(
             zxBCdt,
             [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
             axis=-1
         )
-        
+
         # Process convolution
-        xBC, conv_state = self.conv1d(xBC, cache[0])
+        xBC, conv_state = self.conv1d(xBC, conv_state)
         xBC = silu(xBC)
-        if cache is not None:
-            cache[0] = conv_state
         xBC = xBC[:, :seq_len, :]
-        
-        # Split and reshape conv output
+
+        # Split conv output
         x, B, C = mx.split(
             xBC,
             [self.d_inner, self.d_inner + self.d_state * self.n_groups],
             axis=-1
         )
-        
+
         # Reshape for SSM processing
         x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
-        
-        # Get parameters for attention computation
+
+        # Process with parallel attention
         A = -mx.exp(self.A_log)
-        
-        # Compute parallel attention
         y, next_state = ssd_forward_attn(
             x=x,
             dt=dt,
@@ -244,17 +239,24 @@ class Mamba2Block(nn.Module):
             D=self.D,
             dt_bias=self.dt_bias,
             dt_min=self.args.time_step_min,
-            dt_max=self.args.time_step_max,
+            dt_max=self.args.time_step_max
         )
-        
-        # Update cache
-        if cache is not None:
-            cache[1] = next_state
-        
-        # Apply normalization and output projection
-        y = self.norm(y, z)
+
+        # Apply normalization based on norm_before_gate setting
+        if self.args.norm_before_gate:
+            y = self.norm(y)
+            y = y * nn.silu(z)
+        else:
+            y = y * nn.silu(z)
+            y = self.norm(y)
+
+        # Final projection
         y = self.out_proj(y)
-        
+
+        # Update cache
+        cache[0] = conv_state
+        cache[1] = next_state
+
         return y