mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-22 04:56:41 +08:00
Avoid saving intermediate results; keep `y` in `Bilinear` for clarity (it could be replaced with `x1`)
This commit is contained in:
parent
8e22ab1e5a
commit
40641ece39
@ -64,10 +64,10 @@ class Linear(Module):
|
|||||||
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
return f"input_dims={self.weight.shape[1]}, output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
||||||
|
|
||||||
def __call__(self, x: mx.array) -> mx.array:
|
def __call__(self, x: mx.array) -> mx.array:
|
||||||
y = x @ self.weight.T
|
x = x @ self.weight.T
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
y = y + self.bias
|
x = x + self.bias
|
||||||
return y
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Bilinear(Module):
|
class Bilinear(Module):
|
||||||
@ -103,7 +103,7 @@ class Bilinear(Module):
|
|||||||
self.weight = mx.random.uniform(
|
self.weight = mx.random.uniform(
|
||||||
low=-scale,
|
low=-scale,
|
||||||
high=scale,
|
high=scale,
|
||||||
shape=(output_dims, input1_dims, input2_dims),
|
shape=(1, output_dims, input1_dims, input2_dims),
|
||||||
)
|
)
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = mx.random.uniform(
|
self.bias = mx.random.uniform(
|
||||||
@ -114,16 +114,15 @@ class Bilinear(Module):
|
|||||||
|
|
||||||
def _extra_repr(self) -> str:
|
def _extra_repr(self) -> str:
|
||||||
return (
|
return (
|
||||||
f"input1_dims={self.weight.shape[1]}, input2_dims={self.weight.shape[2]}, "
|
f"input1_dims={self.weight.shape[2]}, input2_dims={self.weight.shape[3]}, "
|
||||||
f"output_dims={self.weight.shape[0]}, bias={'bias' in self}"
|
f"output_dims={self.weight.shape[1]}, bias={'bias' in self}"
|
||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||||
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
|
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
|
||||||
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
|
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
|
||||||
w = self.weight.reshape(1, *self.weight.shape)
|
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
|
||||||
z = mx.squeeze(x1 @ w, -2).swapaxes(-1, -2)
|
y = mx.squeeze(x2 @ y, -2)
|
||||||
y = mx.squeeze(x2 @ z, -2)
|
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
y = y + self.bias
|
y = y + self.bias
|
||||||
return y
|
return y
|
||||||
|
Loading…
Reference in New Issue
Block a user