From e1c12343b69829cfe1a8718812c2d84d07ff1549 Mon Sep 17 00:00:00 2001 From: zorea Date: Sun, 31 Dec 2023 18:40:03 +0200 Subject: [PATCH] Remove unnecessary reshape --- python/mlx/nn/layers/linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 9bb3e1cd5..3aa9adb1b 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -118,7 +118,7 @@ class Bilinear(Module): ) def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: - y = (x1 @ self.weight * x2.reshape(1, *x2.shape)).sum(-1).T + y = (x1 @ self.weight * x2).sum(-1).T if "bias" in self: y = y + self.bias return y