mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 14:31:14 +08:00
Fix the implementation of the Bilinear layer (#347)
This commit is contained in:
parent
99c80a2c8b
commit
436bec9fd9
@ -101,7 +101,7 @@ class Bilinear(Module):
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(1, output_dims, input1_dims, input2_dims),
|
||||
shape=(output_dims, input2_dims, input1_dims),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.random.uniform(
|
||||
@ -111,16 +111,31 @@ class Bilinear(Module):
|
||||
)
|
||||
|
||||
def _extra_repr(self) -> str:
|
||||
out, in2, in1 = self.weight.shape
|
||||
return (
|
||||
f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
|
||||
f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
|
||||
f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
|
||||
f"bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||
x1 = mx.expand_dims(x1, axis=(-2, -3))
|
||||
x2 = mx.expand_dims(x2, axis=(-2))
|
||||
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
|
||||
y = mx.squeeze(x2 @ y, -2)
|
||||
# Normalize shapes
|
||||
out, in2, in1 = self.weight.shape
|
||||
xshape = x1.shape[:-1]
|
||||
x1 = x1.reshape(-1, in1)
|
||||
x2 = x2.reshape(-1, 1, in2)
|
||||
|
||||
# Perform the bilinear transformation
|
||||
w = self.weight.reshape(out * in2, in1)
|
||||
y = x1 @ w.T
|
||||
y = y.reshape(-1, out, in2).swapaxes(-2, -1)
|
||||
y = x2 @ y
|
||||
y = y.squeeze(1)
|
||||
|
||||
# Reset the shape
|
||||
y = y.reshape(*xshape, out)
|
||||
|
||||
# Apply the bias
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
|
||||
return y
|
||||
|
Loading…
Reference in New Issue
Block a user