Fix the implementation of the Bilinear layer (#347)

This commit is contained in:
Angelos Katharopoulos 2024-01-02 16:46:18 -08:00 committed by GitHub
parent 99c80a2c8b
commit 436bec9fd9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -101,7 +101,7 @@ class Bilinear(Module):
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(1, output_dims, input1_dims, input2_dims),
shape=(output_dims, input2_dims, input1_dims),
)
if bias:
self.bias = mx.random.uniform(
@ -111,16 +111,31 @@ class Bilinear(Module):
)
def _extra_repr(self) -> str:
    """Return the layer's dimensions and bias flag for its printed repr.

    The sizes are read back from the weight, which is stored with shape
    ``(output_dims, input2_dims, input1_dims)``.
    """
    # Unpack in storage order; the names are then emitted in the
    # constructor's order (input1, input2, output).
    out, in2, in1 = self.weight.shape
    return (
        f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
        f"bias={'bias' in self}"
    )
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
    """Apply the bilinear transformation to a pair of inputs.

    For every row, computes
    ``y[o] = sum_{i, j} x1[i] * weight[o, j, i] * x2[j]`` and adds the
    bias when the layer has one.

    Args:
        x1: Array whose last dimension is ``input1_dims``.
        x2: Array whose last dimension is ``input2_dims``. It is flattened
            independently of ``x1``, so it is expected to have the same
            leading dimensions as ``x1``.

    Returns:
        Array with the leading dimensions of ``x1`` and a last dimension
        of ``output_dims``.
    """
    # Normalize shapes: flatten all leading dimensions into one batch axis.
    out, in2, in1 = self.weight.shape
    xshape = x1.shape[:-1]
    x1 = x1.reshape(-1, in1)
    x2 = x2.reshape(-1, 1, in2)

    # Perform the bilinear transformation.
    # First contract x1 with the weight using a single 2-D matmul:
    # (N, in1) @ (in1, out * in2) -> (N, out * in2).
    w = self.weight.reshape(out * in2, in1)
    y = x1 @ w.T
    # Then contract with x2 as a batched matmul:
    # (N, 1, in2) @ (N, in2, out) -> (N, 1, out).
    y = y.reshape(-1, out, in2).swapaxes(-2, -1)
    y = x2 @ y
    y = y.squeeze(1)

    # Reset the shape: restore the leading dimensions of x1.
    y = y.reshape(*xshape, out)

    # Apply the bias
    if "bias" in self:
        y = y + self.bias

    return y