From a883e39f41c47936ff2e2f10ce3940c40ff5338e Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Thu, 12 Dec 2024 21:08:33 +0100
Subject: [PATCH] optimizing the code for faster inference but still generates gibberish

---
 llms/mlx_lm/models/mamba2.py | 16 ++--------------
 1 file changed, 2 insertions(+), 14 deletions(-)

diff --git a/llms/mlx_lm/models/mamba2.py b/llms/mlx_lm/models/mamba2.py
index b85d3667..0ed62287 100644
--- a/llms/mlx_lm/models/mamba2.py
+++ b/llms/mlx_lm/models/mamba2.py
@@ -109,23 +109,11 @@ class DepthWiseConv1d(nn.Module):
         else:
             x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
 
-        # Adjust the weight tensor to match the input channels
-        if C != self.channels:
-            adjusted_weight = self.weight[:C, :, :]
-        else:
-            adjusted_weight = self.weight
-
-        y = mx.conv_general(x, adjusted_weight, groups=C)
-
-        if self.bias is not None:
-            # Adjust the bias to match the input channels
-            adjusted_bias = self.bias[:C] if C != self.channels else self.bias
-            y = y + adjusted_bias
-
+        y = mx.conv_general(x, self.weight, groups=C)
+        y = y + self.bias
         return y, x[:, -K + 1:, :]
 
 
-
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
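
Note (not part of the patch): the simplified branch leans on MLX's grouped convolution to express the depthwise filter directly. Below is a minimal standalone sketch of that pattern, assuming mlx.core.conv_general accepts a groups argument, the (batch, length, channels) layout used above, and hypothetical toy shapes chosen only for illustration.

import mlx.core as mx

# Hypothetical toy shapes: batch, sequence length, channels, kernel width
B, L, C, K = 1, 8, 4, 3
x = mx.random.normal((B, L, C))        # input in (batch, length, channels) layout
weight = mx.random.normal((C, K, 1))   # one K-tap filter per channel (depthwise)
bias = mx.zeros((C,))

# Causal left-padding, then a grouped convolution with groups == channel count,
# mirroring the simplified path in the patch.
x_padded = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x_padded, weight, groups=C)
y = y + bias
print(y.shape)  # (1, 8, 4)

Like the patched code, the sketch adds the bias unconditionally, so it assumes the bias is always materialized as an array rather than possibly None.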