mirror of
https://github.com/ml-explore/mlx.git
synced 2025-08-21 12:06:42 +08:00
Changed math formula in Linear
Added more explanation to math formulas Changed x1, x2 reshape to support all inputs sizes
This commit is contained in:
parent
32ffcac047
commit
0296b312ac
@ -29,14 +29,13 @@ class Linear(Module):
|
|||||||
|
|
||||||
.. math::
|
.. math::
|
||||||
|
|
||||||
y = W^\top x + b
|
y = x W^\top + b
|
||||||
|
|
||||||
where:
|
where:
|
||||||
:math:`W` has shape ``[output_dims, input_dims]``.
|
where :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``.
|
||||||
:math:`b` has shape ``[output_dims, ]``.
|
|
||||||
|
|
||||||
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
|
The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
|
||||||
:math:`k = \frac{1}{\sqrt{input\_dims}}`.
|
where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_dims (int): The dimensionality of the input features
|
input_dims (int): The dimensionality of the input features
|
||||||
@ -79,13 +78,12 @@ class Bilinear(Module):
|
|||||||
|
|
||||||
y_i = x_1^\top W_i x_2 + b_i
|
y_i = x_1^\top W_i x_2 + b_i
|
||||||
|
|
||||||
where
|
where:
|
||||||
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``.
|
:math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims ]``,
|
||||||
:math:`b` has shape ``[output_dims, ]``.
|
and :math:`i` indexes the output dimension.
|
||||||
:math:`i` is the index for output dimensions.
|
|
||||||
|
|
||||||
The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where
|
The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`,
|
||||||
:math:`k = \frac{1}{\sqrt{input1\_dims}}`.
|
where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input1_dims (int): The dimensionality of the input1 features
|
input1_dims (int): The dimensionality of the input1 features
|
||||||
@ -119,8 +117,8 @@ class Bilinear(Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
def __call__(self, x1: mx.array, x2: mx.array) -> mx.array:
|
||||||
x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1])
|
x1 = mx.expand_dims(x1, axis=(-2, -3))
|
||||||
x2 = x2.reshape(x2.shape[0], 1, x2.shape[1])
|
x2 = mx.expand_dims(x2, axis=(-2))
|
||||||
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
|
y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2)
|
||||||
y = mx.squeeze(x2 @ y, -2)
|
y = mx.squeeze(x2 @ y, -2)
|
||||||
if "bias" in self:
|
if "bias" in self:
|
||||||
|
Loading…
Reference in New Issue
Block a user