diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 432ab994..747db9e2 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -134,8 +134,6 @@ class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
         self.args = args
-
-        # Same dimensions as before
         self.d_model = args.hidden_size
         self.d_state = args.state_size
         self.d_conv = args.conv_kernel
@@ -146,16 +144,13 @@ class Mamba2Block(nn.Module):
         self.d_head = self.d_inner // self.n_heads
         self.chunk_size = args.chunk_size
 
-        # Input projection
         d_in_proj = 2 * self.d_inner + 2 * self.n_groups * self.d_state + self.n_heads
         self.in_proj = nn.Linear(self.d_model, d_in_proj, bias=args.use_bias)
 
-        # Parameters
         self.dt_bias = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range
         self.D = mx.random.normal((self.n_heads,)) * args.initializer_range
 
-        # Convolution
         self.conv1d = DepthWiseConv1d(
             channels=self.d_inner + 2 * self.n_groups * self.d_state,
             kernel_size=self.d_conv,
@@ -163,46 +158,38 @@ class Mamba2Block(nn.Module):
             padding=self.d_conv-1
         )
 
-        # Output projections
         self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon)
         self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=args.use_bias)
 
     def __call__(self, u: mx.array, cache=None):
         batch_size, seq_len, _ = u.shape
 
-        # Get or initialize states from cache
         if cache is None:
-            cache = [None, None]  # [conv_state, ssm_state]
-        conv_state, _ = cache  # We ignore ssm_state as it's not used in the parallel version
+            cache = [None, None]
+        conv_state, _ = cache
 
-        # Project input
         zxBCdt = self.in_proj(u)
 
-        # Split projections
         z, xBC, dt = mx.split(
             zxBCdt,
             [self.d_inner, 2 * self.d_inner + 2 * self.n_groups * self.d_state],
             axis=-1
         )
 
-        # Process convolution
         xBC, conv_state = self.conv1d(xBC, conv_state)
         xBC = silu(xBC)
         xBC = xBC[:, :seq_len, :]
 
-        # Split conv output
         x, B, C = mx.split(
             xBC,
             [self.d_inner, self.d_inner + self.d_state * self.n_groups],
             axis=-1
         )
 
-        # Reshape for SSM processing
         x = mx.reshape(x, (batch_size, seq_len, self.n_heads, self.d_head))
         B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1))
         C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1))
 
-        # Process with parallel attention
         A = -mx.exp(self.A_log)
         y, next_state = ssd_forward_attn(
             x=x,
@@ -216,7 +203,6 @@ class Mamba2Block(nn.Module):
             dt_max=self.args.time_step_max
         )
 
-        # Apply normalization based on norm_before_gate setting
         if self.args.norm_before_gate:
             y = self.norm(y)
             y = y * nn.silu(z)
@@ -224,10 +210,8 @@ class Mamba2Block(nn.Module):
             y = y * nn.silu(z)
             y = self.norm(y)
 
-        # Final projection
         y = self.out_proj(y)
 
-        # Update cache
         cache[0] = conv_state
         cache[1] = next_state
 
@@ -242,8 +226,8 @@ class ResidualBlock(nn.Module):
         self.norm = nn.RMSNorm(args.hidden_size)
 
     def __call__(self, x: mx.array, cache):
-        # if self.residual_in_fp32:
-        #     x = x.astype(mx.float32)
+        if self.residual_in_fp32:
+            x = x.astype(mx.float32)
         normed = self.norm(x)
         output = self.mixer(normed, cache)
         return output + x
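
Reviewer note: a minimal sketch (not part of the patch) of how the fused in_proj output is split into z, xBC, and dt in the __call__ above. The hyperparameter values below are made-up placeholders; only the split arithmetic mirrors the d_in_proj layout used in Mamba2Block.

# Sketch with placeholder sizes; mirrors the mx.split boundaries in the diff.
import mlx.core as mx

d_inner, d_state, n_groups, n_heads = 1536, 128, 1, 48

d_in_proj = 2 * d_inner + 2 * n_groups * d_state + n_heads
zxBCdt = mx.zeros((1, 16, d_in_proj))  # (batch, seq_len, d_in_proj)

# Split points are cumulative indices along the last axis: z | xBC | dt
z, xBC, dt = mx.split(
    zxBCdt,
    [d_inner, 2 * d_inner + 2 * n_groups * d_state],
    axis=-1,
)
assert z.shape[-1] == d_inner                              # gate branch
assert xBC.shape[-1] == d_inner + 2 * n_groups * d_state   # conv/SSM branch
assert dt.shape[-1] == n_heads                             # per-head step sizes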