inference on codestral works but is giberish

This commit is contained in:
Goekdeniz-Guelmez 2024-12-10 17:34:44 +01:00
parent ddad2105ef
commit 9f8a6a3509

View File

@ -87,16 +87,15 @@ def ssd(x, A, B, C, chunk_size):
outputs.append(y)
return mx.concatenate(outputs, axis=1), state
class DepthWiseConv1d(nn.Module):
def __init__(self, channels, kernel_size, bias=True, padding=0):
super().__init__()
self.channels = channels
self.kernel_size = kernel_size
self.groups = channels
self.padding = padding
self.weight = mx.random.normal((self.channels, kernel_size, 1))
self.weight = mx.random.normal((channels, kernel_size, 1))
self.bias = mx.zeros((channels,)) if bias else None
def __call__(self, x, cache=None):
@ -108,14 +107,23 @@ class DepthWiseConv1d(nn.Module):
else:
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
y = mx.conv_general(x, self.weight, groups=self.groups)
# Adjust the weight tensor to match the input channels
if C != self.channels:
adjusted_weight = self.weight[:C, :, :]
else:
adjusted_weight = self.weight
y = mx.conv_general(x, adjusted_weight, groups=C)
if self.bias is not None:
y = y + self.bias
# Adjust the bias to match the input channels
adjusted_bias = self.bias[:C] if C != self.channels else self.bias
y = y + adjusted_bias
return y, x[:, -K + 1:, :]
class Mamba2Block(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()