diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index 3360a615..141ffeee 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -89,63 +89,29 @@ def ssd(x, A, B, C, chunk_size):
 
 
 class DepthWiseConv1d(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, bias=True, groups=None, padding=0):
+    def __init__(self, channels, kernel_size, bias=True, padding=0):
         super().__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
+        self.channels = channels
         self.kernel_size = kernel_size
         self.padding = padding
-        self.groups = groups if groups is not None else in_channels
-
-        assert in_channels == out_channels, "In and out channels must be same for depthwise convolution"
-        assert self.groups == in_channels, "Groups must be equal to in_channels for depthwise convolution"
-
-        self.weight = mx.random.normal((in_channels, 1, kernel_size))
-        self.bias = mx.zeros((out_channels,)) if bias else None
+        self.weight = mx.random.normal((self.channels, kernel_size, 1))
+        self.bias = mx.zeros((channels,)) if bias else None
 
-    def __call__(self, x: mx.array, cache=None) -> mx.array:
+    def __call__(self, x, cache=None):
         B, L, C = x.shape
-        K = self.kernel_size
+        groups, K, _ = self.weight.shape
 
-        assert C == self.in_channels, f"Input channels {C} doesn't match expected {self.in_channels}"
-
         if cache is not None:
-            # Access conv_state directly from cache[0]
-            if cache[0] is None:
-                cache[0] = mx.zeros((B, K-1, C))
+            x = mx.concatenate([cache, x], axis=1)
+        else:
+            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
 
-            x = mx.concatenate([cache[0], x], axis=1)
-
-        outputs = []
-        for c in range(C):
-            x_c = x[:, :, c]
-            x_c = mx.expand_dims(x_c, axis=1)
-
-            w_c = self.weight[c]
-            if w_c.ndim == 2:
-                w_c = mx.expand_dims(w_c, axis=0)
-            elif w_c.ndim == 1:
-                w_c = mx.expand_dims(mx.expand_dims(w_c, axis=0), axis=0)
-
-            y_c = mx.conv_general(
-                x_c,
-                w_c,
-                stride=1,
-                padding=0
-            )
-            if self.bias is not None:
-                y_c = y_c + self.bias[c]
-
-            y_c = mx.squeeze(y_c, axis=1)
-            outputs.append(y_c)
+        y = mx.conv_general(x, self.weight, groups=groups)
 
-        y = mx.stack(outputs, axis=-1)
+        if self.bias is not None:
+            y = y + self.bias
 
-        # Update cache directly using cache[0]
-        if cache is not None:
-            cache[0] = x[:, -K+1:, :] if x.shape[1] >= K else x
-
-        return y
+        return y, x[:, -K + 1:, :]
 
 
 class Mamba2Block(nn.Module):
@@ -169,11 +135,9 @@ class Mamba2Block(nn.Module):
         # Convolution
         conv_dim = self.d_inner + 2 * self.d_state
         self.conv1d = DepthWiseConv1d(
-            in_channels=conv_dim,
-            out_channels=conv_dim,
+            channels=conv_dim,
             kernel_size=self.d_conv,
-            bias=args.use_conv_bias,
-            groups=conv_dim
+            bias=args.use_conv_bias
         )
 
         # SSM parameters
@@ -206,7 +170,9 @@ class Mamba2Block(nn.Module):
         dt = mx.maximum(dt, self.args.time_step_floor)
 
         # Convolution and activation
-        x_conv = self.conv1d(x_conv, cache=[cache[0] if cache else None])
+        x_conv, conv_state = self.conv1d(x_conv, cache[0] if cache else None)
+        if cache is not None:
+            cache[0] = conv_state
        x_conv = silu(x_conv)
 
         # Split conv output
@@ -328,6 +294,12 @@ class Model(nn.Module):
         logits = self.lm_head(hidden)
 
         return logits
+
+    def sanitize(self, weights):
+        for k, v in weights.items():
+            if "conv1d.weight" in k and v.shape[-1] != 1:
+                weights[k] = v.moveaxis(2, 1)
+        return weights
 
     def make_cache(self):
         return [MambaCache() for _ in range(len(self.layers))]
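
Usage note (not part of the patch): a minimal sketch of how the refactored DepthWiseConv1d is expected to be called, assuming the class above is in scope and using illustrative shapes only (batch 2, sequence length 5, 8 channels, kernel size 4). The prefill call left-pads internally with K - 1 zeros; a decode step feeds the returned conv state back in as the cache, which is what Mamba2Block now does with cache[0].

    import mlx.core as mx

    # Assumes DepthWiseConv1d from this file; shapes are illustrative.
    conv = DepthWiseConv1d(channels=8, kernel_size=4)

    # Prefill: no cache, input is left-padded internally with K - 1 zeros.
    x = mx.random.normal((2, 5, 8))
    y, conv_state = conv(x)                    # y: (2, 5, 8), conv_state: (2, 3, 8)

    # Decode step: reuse the returned state as the cache argument.
    x_step = mx.random.normal((2, 1, 8))
    y_step, conv_state = conv(x_step, cache=conv_state)   # y_step: (2, 1, 8)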