diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index a8e8e891..4468cdf1 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -87,16 +87,15 @@ def ssd(x, A, B, C, chunk_size): outputs.append(y) return mx.concatenate(outputs, axis=1), state - + class DepthWiseConv1d(nn.Module): def __init__(self, channels, kernel_size, bias=True, padding=0): super().__init__() self.channels = channels self.kernel_size = kernel_size - self.groups = channels self.padding = padding - self.weight = mx.random.normal((self.channels, kernel_size, 1)) + self.weight = mx.random.normal((channels, kernel_size, 1)) self.bias = mx.zeros((channels,)) if bias else None def __call__(self, x, cache=None): @@ -108,14 +107,23 @@ class DepthWiseConv1d(nn.Module): else: x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - y = mx.conv_general(x, self.weight, groups=self.groups) + # Adjust the weight tensor to match the input channels + if C != self.channels: + adjusted_weight = self.weight[:C, :, :] + else: + adjusted_weight = self.weight + + y = mx.conv_general(x, adjusted_weight, groups=C) if self.bias is not None: - y = y + self.bias + # Adjust the bias to match the input channels + adjusted_bias = self.bias[:C] if C != self.channels else self.bias + y = y + adjusted_bias return y, x[:, -K + 1:, :] + class Mamba2Block(nn.Module): def __init__(self, args: ModelArgs): super().__init__()