diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 9bb3e1cd5..3aa9adb1b 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -118,7 +118,7 @@ class Bilinear(Module): ) def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: - y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T + y = (x1 @ self.weight * x2).sum(-1).T if "bias" in self: y = y + self.bias return y