diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index 6adc6469..7c044dba 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -42,6 +42,29 @@ class ModelArgs(BaseModelArgs): self.time_step_rank = math.ceil(self.hidden_size / 16) +class DepthWiseConv1d(nn.Module): + def __init__(self, channels, kernel_size, bias=True, padding=0): + super().__init__() + self.channels = channels + self.kernel_size = kernel_size + self.padding = padding + self.weight = mx.random.normal((channels, kernel_size, 1)) + self.bias = mx.zeros((channels,)) if bias else None + + def __call__(self, x, cache=None): + B, L, C = x.shape + _, K, _ = self.weight.shape + + if cache is not None: + x = mx.concatenate([cache, x], axis=1) + else: + x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) + + y = mx.conv_general(x, self.weight, groups=C) + y = y + self.bias + return y, x[:, -K + 1:, :] + + def ssd_forward_attn( x: mx.array, dt: mx.array, @@ -144,13 +167,11 @@ class Mamba2Block(nn.Module): self.A_log = mx.random.normal((self.n_heads,)) * args.initializer_range self.D = mx.random.normal((self.n_heads,)) * args.initializer_range - conv_channels = self.d_inner + 2 * self.n_groups * self.d_state - self.conv1d = nn.Conv1d( - in_channels=conv_channels, - out_channels=conv_channels, + self.conv1d = DepthWiseConv1d( + channels=self.d_inner + 2 * self.n_groups * self.d_state, kernel_size=self.d_conv, - groups=conv_channels, - padding=self.d_conv - 1, + bias=args.use_conv_bias, + padding=self.d_conv-1 ) self.norm = nn.RMSNorm(self.d_inner, eps=args.layer_norm_epsilon) @@ -172,36 +193,10 @@ class Mamba2Block(nn.Module): axis=-1 ) - # Handle convolution with caching - xBC = mx.swapaxes(xBC, 1, 2) # [B, L, C] -> [B, C, L] - - if conv_state is not None and seq_len > 0: - # Concatenate cached state with current input - xBC_with_cache = mx.concatenate([conv_state, xBC], axis=2) - elif seq_len > 0: - # For the first call, pad with zeros - padding = mx.zeros((batch_size, xBC.shape[1], self.d_conv - 1)) - xBC_with_cache = mx.concatenate([padding, xBC], axis=2) - else: - xBC_with_cache = conv_state if conv_state is not None else mx.zeros((batch_size, xBC.shape[1], 0)) - - # Save state for next iteration - if seq_len > 0: - next_conv_state = xBC_with_cache[:, :, -(self.d_conv - 1):] - else: - next_conv_state = conv_state - - # Apply regular convolution using nn.Conv1d - if seq_len > 0: - # Use the standard Conv1d module for the actual computation - xBC_conv = self.conv1d(xBC_with_cache) - xBC = xBC_conv[:, :, -seq_len:] # Take only the relevant output positions - xBC = mx.swapaxes(xBC, 1, 2) # [B, C, L] -> [B, L, C] - xBC = xBC * mx.sigmoid(xBC) - else: - # Handle empty sequence case - xBC = mx.swapaxes(xBC, 1, 2) # [B, C, L] -> [B, L, C] - + xBC, conv_state = self.conv1d(xBC, conv_state) + xBC = xBC * mx.sigmoid(xBC) + xBC = xBC[:, :seq_len, :] + x, B, C = mx.split( xBC, [self.d_inner, self.d_inner + self.d_state * self.n_groups], @@ -212,6 +207,7 @@ class Mamba2Block(nn.Module): B = mx.reshape(B, (batch_size, seq_len, self.n_groups, -1)) C = mx.reshape(C, (batch_size, seq_len, self.n_groups, -1)) + A = -mx.exp(self.A_log) y, next_ssm_state = ssd_forward_attn( x=x, dt=dt, @@ -234,7 +230,7 @@ class Mamba2Block(nn.Module): y = self.out_proj(y) - cache[0] = next_conv_state + cache[0] = conv_state cache[1] = next_ssm_state return y