Remove unnecessary reshape

This commit is contained in:
zorea 2023-12-31 18:40:03 +02:00
parent bfd7e786c3
commit e1c12343b6

View File

@ -118,7 +118,7 @@ class Bilinear(Module):
)
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T
y = (x1 @ self.weight * x2).sum(-1).T
if "bias" in self:
y = y + self.bias
return y