mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-19 23:51:14 +08:00
Fix the implementation of the Bilinear layer (#347)
This commit is contained in:
parent
99c80a2c8b
commit
436bec9fd9
@ -101,7 +101,7 @@ class Bilinear(Module):
|
|||||||
self.weight = mx.random.uniform(
|
self.weight = mx.random.uniform(
|
||||||
low=-scale,
|
low=-scale,
|
||||||
high=scale,
|
high=scale,
|
||||||
shape=(1, output_dims, input1_dims, input2_dims),
|
shape=(output_dims, input2_dims, input1_dims),
|
||||||
)
|
)
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = mx.random.uniform(
|
self.bias = mx.random.uniform(
|
||||||
@ -111,16 +111,31 @@ class Bilinear(Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def _extra_repr(self) -> str:
|
def _extra_repr(self) -> str:
|
||||||
|
out, in2, in1 = self.weight.shape
|
||||||
return (
|
return (
|
||||||
f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
|
f"input1_dims={in1}, input2_dims={in2}, output_dims={out}, "
|
||||||
f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
|
f"bias={'bias' in self}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||||
x1 = mx.expand_dims(x1, axis=(-2, -3))
|
# Normalize shapes
|
||||||
x2 = mx.expand_dims(x2, axis=(-2))
|
out, in2, in1 = self.weight.shape
|
||||||
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
|
xshape = x1.shape[:-1]
|
||||||
y = mx.squeeze(x2 @ y, -2)
|
x1 = x1.reshape(-1, in1)
|
||||||
|
x2 = x2.reshape(-1, 1, in2)
|
||||||
|
|
||||||
|
# Perform the bilinear transformation
|
||||||
|
w = self.weight.reshape(out * in2, in1)
|
||||||
|
y = x1 @ w.T
|
||||||
|
y = y.reshape(-1, out, in2).swapaxes(-2, -1)
|
||||||
|
y = x2 @ y
|
||||||
|
y = y.squeeze(1)
|
||||||
|
|
||||||
|
# Reset the shape
|
||||||
|
y = y.reshape(*xshape, out)
|
||||||
|
|
||||||
|
# Apply the bias
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
y = y + self.bias
|
y = y + self.bias
|
||||||
|
|
||||||
return y
|
return y
|
||||||
|
Loading…
Reference in New Issue
Block a user