Changed math formula in Linear

Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
This commit is contained in:
zorea 2024-01-02 08:54:07 +02:00
parent 32ffcac047
commit 0296b312ac

View File

@ -29,14 +29,13 @@ class Linear(Module):
.. math:: .. math::
y = W^\top x + b y = x W^\top + b
where: where:
:math:`W` has shape ``[output_dims, input_dims]``. where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.
:math:`b` has shape ``[output_dims, ]``.
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
:math:`k = \frac{1}{\sqrt{input\_dims}}`. where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
Args: Args:
input_dims (int): The dimensionality of the input features input_dims (int): The dimensionality of the input features
@ -79,13 +78,12 @@ class Bilinear(Module):
y_i = x_1^\top W_i x_2 + b_i y_i = x_1^\top W_i x_2 + b_i
where where:
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``. :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``,
:math:`b` has shape ``[output_dims, ]``. and :math:`i` indexes the output dimension.
:math:`i` is the index for output dimensions.
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
:math:`k = \frac{1}{\sqrt{input1\_dims}}`. where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.
Args: Args:
input1_dims (int): The dimensionality of the input1 features input1_dims (int): The dimensionality of the input1 features
@ -119,8 +117,8 @@ class Bilinear(Module):
) )
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1]) x1 = mx.expand_dims(x1, axis=(-2, -3))
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1]) x2 = mx.expand_dims(x2, axis=(-2))
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2) y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
y = mx.squeeze(x2 @ y, -2) y = mx.squeeze(x2 @ y, -2)
if "bias" in self: if "bias" in self: