From 436bec9fd9a1fd009c359504305695a0c767344c Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Tue, 2 Jan 2024 16:46:18 -0800
Subject: [PATCH] Fix the implementation of the Bilinear layer (#347)

---
 python/mlx/nn/layers/linear.py | 29 ++++++++++++++++++++++-------
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 9c4ca6781..77b340721 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -101,7 +101,7 @@ class Bilinear(Module):
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(1, output_dims, input1_dims, input2_dims),
+            shape=(output_dims, input2_dims, input1_dims),
         )
         if bias:
             self.bias = mx.random.uniform(
@@ -111,16 +111,31 @@ class Bilinear(Module):
             )
 
     def _extra_repr(self) -> str:
+        out, in2, in1 = self.weight.shape
         return (
-            f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
-            f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
+            f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
+            f"bias={'bias' in self}"
         )
 
     def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
-        x1 = mx.expand_dims(x1, axis=(-2, -3))
-        x2 = mx.expand_dims(x2, axis=(-2))
-        y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
-        y = mx.squeeze(x2 @ y, -2)
+        # Normalize shapes
+        out, in2, in1 = self.weight.shape
+        xshape = x1.shape[:-1]
+        x1 = x1.reshape(-1, in1)
+        x2 = x2.reshape(-1, 1, in2)
+
+        # Perform the bilinear transformation
+        w = self.weight.reshape(out * in2, in1)
+        y = x1 @ w.T
+        y = y.reshape(-1, out, in2).swapaxes(-2, -1)
+        y = x2 @ y
+        y = y.squeeze(1)
+
+        # Reset the shape
+        y = y.reshape(*xshape, out)
+
+        # Apply the bias
         if "bias" in self:
             y = y + self.bias
+
         return y