diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py
index 7a09d4e76..be6702084 100644
--- a/python/mlx/nn/layers/linear.py
+++ b/python/mlx/nn/layers/linear.py
@@ -64,10 +64,10 @@ class Linear(Module):
         return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
 
     def __call__(self, x: mx.array) -> mx.array:
-        y = x @ self.weight.T
+        x = x @ self.weight.T
         if "bias" in self:
-            y = y + self.bias
-        return y
+            x = x + self.bias
+        return x
 
 
 class Bilinear(Module):
@@ -103,7 +103,7 @@ class Bilinear(Module):
         self.weight = mx.random.uniform(
             low=-scale,
             high=scale,
-            shape=(output_dims, input1_dims, input2_dims),
+            shape=(1, output_dims, input1_dims, input2_dims),
         )
         if bias:
             self.bias = mx.random.uniform(
@@ -114,16 +114,15 @@ class Bilinear(Module):
 
     def _extra_repr(self) -> str:
         return (
-            f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
-            f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
+            f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
+            f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
         )
 
     def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
         x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
         x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
-        w = self.weight.reshape(1, *self.weight.shape)
-        z = mx.squeeze(x1 @ w, -2).swapaxes(-1, -2)
-        y = mx.squeeze(x2 @ z, -2)
+        y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
+        y = mx.squeeze(x2 @ y, -2)
         if "bias" in self:
             y = y + self.bias
         return y
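
---

Reviewer note (not part of the patch): the Linear hunk only rebinds the input name x instead of introducing y, so behavior is unchanged. The Bilinear hunks store the weight with a leading broadcast axis of size 1, which removes the reshape previously done on every forward pass; the shape indices in _extra_repr shift by one accordingly. Below is a minimal sanity-check sketch, assuming an MLX install; the sizes B, d1, d2, d_out and variable names are made up for illustration and verify that the patched forward still computes y[n, k] = x1[n] @ W[k] @ x2[n].

import mlx.core as mx

B, d1, d2, d_out = 4, 5, 6, 3
x1 = mx.random.normal((B, d1))
x2 = mx.random.normal((B, d2))
# Weight laid out as after this patch: (1, output_dims, input1_dims, input2_dims).
w = mx.random.normal((1, d_out, d1, d2))

# Forward pass, mirroring the patched Bilinear.__call__.
a = x1.reshape(B, 1, 1, d1)                 # (B, 1, 1, d1)
b = x2.reshape(B, 1, d2)                    # (B, 1, d2)
y = mx.squeeze(a @ w, -2).swapaxes(-1, -2)  # (B, d2, d_out)
y = mx.squeeze(b @ y, -2)                   # (B, d_out)

# Reference: compute each output unit k explicitly as x1 @ W[k], then
# contract against x2. w[0] drops the broadcast axis added by the patch.
ref = mx.stack(
    [((x1 @ w[0, k]) * x2).sum(axis=-1) for k in range(d_out)], axis=-1
)
assert mx.allclose(y, ref).item()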