mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 20:46:46 +08:00
avoid saving intermediate results, kept y in bilinear for better clarity (can be replaced with x1)
This commit is contained in:
parent
8e22ab1e5a
commit
40641ece39
@ -64,10 +64,10 @@ class Linear(Module):
|
||||
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||
|
||||
def __call__(self, x: mx.array) -> mx.array:
|
||||
y = x @ self.weight.T
|
||||
x = x @ self.weight.T
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
x = x + self.bias
|
||||
return x
|
||||
|
||||
|
||||
class Bilinear(Module):
|
||||
@ -103,7 +103,7 @@ class Bilinear(Module):
|
||||
self.weight = mx.random.uniform(
|
||||
low=-scale,
|
||||
high=scale,
|
||||
shape=(output_dims, input1_dims, input2_dims),
|
||||
shape=(1, output_dims, input1_dims, input2_dims),
|
||||
)
|
||||
if bias:
|
||||
self.bias = mx.random.uniform(
|
||||
@ -114,16 +114,15 @@ class Bilinear(Module):
|
||||
|
||||
def _extra_repr(self) -> str:
|
||||
return (
|
||||
f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
|
||||
f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||
f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
|
||||
f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
|
||||
)
|
||||
|
||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
|
||||
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
|
||||
w = self.weight.reshape(1, *self.weight.shape)
|
||||
z = mx.squeeze(x1 @ w, -2).swapaxes(-1, -2)
|
||||
y = mx.squeeze(x2 @ z, -2)
|
||||
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
|
||||
y = mx.squeeze(x2 @ y, -2)
|
||||
if "bias" in self:
|
||||
y = y + self.bias
|
||||
return y
|
||||
|
Loading…
Reference in New Issue
Block a user