diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index d8267360a..7a09d4e76 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -119,7 +119,11 @@ class Bilinear(Module): ) def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: - y = (x1 @ self.weight * x2).sum(-1).T + x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1]) + x2 = x2.reshape(x2.shape[0], 1, x2.shape[1]) + w = self.weight.reshape(1, *self.weight.shape) + z = mx.squeeze(x1 @ w, -2).swapaxes(-1, -2) + y = mx.squeeze(x2 @ z, -2) if "bias" in self: y = y + self.bias return y