mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
optimizing the code for faster inference but still generates giberish
This commit is contained in:
parent
c1d9ec329c
commit
a883e39f41
@ -109,23 +109,11 @@ class DepthWiseConv1d(nn.Module):
|
||||
else:
|
||||
x = mx.pad(x, [(0, 0), (K - 1, 0), (0, 0)])
|
||||
|
||||
# Adjust the weight tensor to match the input channels
|
||||
if C != self.channels:
|
||||
adjusted_weight = self.weight[:C, :, :]
|
||||
else:
|
||||
adjusted_weight = self.weight
|
||||
|
||||
y = mx.conv_general(x, adjusted_weight, groups=C)
|
||||
|
||||
if self.bias is not None:
|
||||
# Adjust the bias to match the input channels
|
||||
adjusted_bias = self.bias[:C] if C != self.channels else self.bias
|
||||
y = y + adjusted_bias
|
||||
|
||||
y = mx.conv_general(x, self.weight, groups=C)
|
||||
y = y + self.bias
|
||||
return y, x[:, -K + 1:, :]
|
||||
|
||||
|
||||
|
||||
class Mamba2Block(nn.Module):
|
||||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
|
Loading…
Reference in New Issue
Block a user