mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 13:07:51 +08:00
Changed bilinear computation to two matrix multiplications
This commit is contained in:
parent
a951283561
commit
8e22ab1e5a
@ -119,7 +119,11 @@ class Bilinear(Module):
|
||||
)
|
||||
|
||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||
y = (x1 @ self.weight * x2).sum(-1).T
|
||||
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
|
||||
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
|
||||
w = self.weight.reshape(1, *self.weight.shape)
|
||||
z = mx.squeeze(x1 @ w, -2).swapaxes(-1, -2)
|
||||
y = mx.squeeze(x2 @ z, -2)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
Loading…
Reference in New Issue
Block a user