Changed math formula in Linear

Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
This commit is contained in:
zorea 2024-01-02 08:54:07 +02:00
parent 32ffcac047
commit 0296b312ac

View File

@ -29,14 +29,13 @@ class Linear(Module):
.. math::
y = W^\top x + b
y = x W^\top + b
where:
:math:`W` has shape ``[output_dims, input_dims]``.
:math:`b` has shape ``[output_dims, ]``.
where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
:math:`k = \frac{1}{\sqrt{input\_dims}}`.
The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
Args:
input_dims (int): The dimensionality of the input features
@ -79,13 +78,12 @@ class Bilinear(Module):
y_i = x_1^\top W_i x_2 + b_i
where
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
:math:`b` has shape ``[output_dims, ]``.
:math:`i` is the index for output dimensions.
where:
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``,
and :math:`i` indexes the output dimension.
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
:math:`k = \frac{1}{\sqrt{input1\_dims}}`.
The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.
Args:
input1_dims (int): The dimensionality of the input1 features
@ -119,8 +117,8 @@ class Bilinear(Module):
)
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
x1 = mx.expand_dims(x1, axis=(-2, -3))
x2 = mx.expand_dims(x2, axis=(-2))
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
y = mx.squeeze(x2 @ y, -2)
if "bias" in self: