mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-28 20:25:22 +08:00)
inference on codestral works but the output is gibberish
This commit is contained in:
parent ddad2105ef
commit 9f8a6a3509
@@ -94,9 +94,8 @@ class DepthWiseConv1d(nn.Module):
         super().__init__()
         self.channels = channels
         self.kernel_size = kernel_size
-        self.groups = channels
         self.padding = padding
-        self.weight = mx.random.normal((self.channels, kernel_size, 1))
+        self.weight = mx.random.normal((channels, kernel_size, 1))
         self.bias = mx.zeros((channels,)) if bias else None
 
     def __call__(self, x, cache=None):
@@ -108,14 +107,23 @@ class DepthWiseConv1d(nn.Module):
         else:
             x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
 
-        y = mx.conv_general(x, self.weight, groups=self.groups)
+        # Adjust the weight tensor to match the input channels
+        if C != self.channels:
+            adjusted_weight = self.weight[:C, :, :]
+        else:
+            adjusted_weight = self.weight
+
+        y = mx.conv_general(x, adjusted_weight, groups=C)
 
         if self.bias is not None:
-            y = y + self.bias
+            # Adjust the bias to match the input channels
+            adjusted_bias = self.bias[:C] if C != self.channels else self.bias
+            y = y + adjusted_bias
+
         return y, x[:, -K + 1:, :]
 
 
 
 class Mamba2Block(nn.Module):
     def __init__(self, args: ModelArgs):
         super().__init__()
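
For reference, below is a minimal standalone sketch of the causal depthwise convolution this layer implements, reassembled from the pre-change side of the diff. Only the padded branch, the weight/bias shapes, and the conv and bias lines are confirmed by the hunks above; the constructor signature, the cache branch (concatenating the previous chunk's tail), and the placeholder sizes in the usage example at the bottom are assumptions based on the usual causal-cache pattern, not the repository's exact code.

import mlx.core as mx
import mlx.nn as nn


class DepthWiseConv1d(nn.Module):
    def __init__(self, channels, kernel_size, bias=True, padding=0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.padding = padding
        # One (kernel_size, 1) filter per channel; MLX conv weights are
        # laid out as (out_channels, kernel_size, in_channels // groups).
        self.weight = mx.random.normal((channels, kernel_size, 1))
        self.bias = mx.zeros((channels,)) if bias else None

    def __call__(self, x, cache=None):
        # x: (batch, length, channels)
        B, L, C = x.shape
        K = self.kernel_size

        if cache is not None:
            # Assumed cache handling: prepend the K-1 positions kept from
            # the previous call so step-by-step decoding stays causal.
            x = mx.concatenate([cache, x], axis=1)
        else:
            # Left-pad so each output position only sees past inputs.
            x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])

        # groups == channels: each channel is convolved with its own filter.
        y = mx.conv_general(x, self.weight, groups=self.channels)
        if self.bias is not None:
            y = y + self.bias

        # Return the output and the last K-1 positions as the next cache.
        return y, x[:, -K + 1:, :]


# Usage sketch (batch=1, length=16, channels=64 are placeholder values):
conv = DepthWiseConv1d(channels=64, kernel_size=4)
y, conv_cache = conv(mx.random.normal((1, 16, 64)))
print(y.shape)  # (1, 16, 64)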