avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)

This commit is contained in:
zorea 2024-01-01 18:06:47 +02:00
parent 8e22ab1e5a
commit 40641ece39

View File

@ -64,10 +64,10 @@ class Linear(Module):
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
def __call__(self, x: mx.array) -> mx.array:
y = x @ self.weight.T
x = x @ self.weight.T
if "bias" in self:
y = y + self.bias
return y
x = x + self.bias
return x
class Bilinear(Module):
@ -103,7 +103,7 @@ class Bilinear(Module):
self.weight = mx.random.uniform(
low=-scale,
high=scale,
shape=(output_dims, input1_dims, input2_dims),
shape=(1, output_dims, input1_dims, input2_dims),
)
if bias:
self.bias = mx.random.uniform(
@ -114,16 +114,15 @@ class Bilinear(Module):
def _extra_repr(self) -> str:
return (
f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
)
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
w = self.weight.reshape(1, *self.weight.shape)
z = mx.squeeze(x1 @ w, -2).swapaxes(-1, -2)
y = mx.squeeze(x2 @ z, -2)
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
y = mx.squeeze(x2 @ y, -2)
if "bias" in self:
y = y + self.bias
return y