diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py index b85d3667..0ed62287 100644 --- a/llms/mlx_lm/models/mamba2.py +++ b/llms/mlx_lm/models/mamba2.py @@ -109,23 +109,11 @@ class DepthWiseConv1d(nn.Module): else: x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)]) - # Adjust the weight tensor to match the input channels - if C != self.channels: - adjusted_weight = self.weight[:C, :, :] - else: - adjusted_weight = self.weight - - y = mx.conv_general(x, adjusted_weight, groups=C) - - if self.bias is not None: - # Adjust the bias to match the input channels - adjusted_bias = self.bias[:C] if C != self.channels else self.bias - y = y + adjusted_bias - + y = mx.conv_general(x, self.weight, groups=C) + y = y + self.bias return y, x[:, -K + 1:, :] - class Mamba2Block(nn.Module): def __init__(self, args: ModelArgs): super().__init__()