diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index be6702084..9c4ca6781 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -29,14 +29,13 @@ class Linear(Module): .. math:: - y = W^\top x + b + y = x W^\top + b where: - :math:`W` has shape ``[output_dims, input_dims]``. - :math:`b` has shape ``[output_dims, ]``. + :math:`W` has shape ``[output_dims, input_dims]`` and :math:`b` has shape ``[output_dims]``. - The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where - :math:`k = \frac{1}{\sqrt{input\_dims}}`. + The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`, + where :math:`k = \frac{1}{\sqrt{D_i}}` and :math:`D_i` is equal to ``input_dims``. Args: input_dims (int): The dimensionality of the input features @@ -79,13 +78,12 @@ class Bilinear(Module): y_i = x_1^\top W_i x_2 + b_i - where - :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``. - :math:`b` has shape ``[output_dims, ]``. - :math:`i` is the index for output dimensions. + where: + :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``, :math:`b` has shape ``[output_dims]``, + and :math:`i` indexes the output dimension. - The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where - :math:`k = \frac{1}{\sqrt{input1\_dims}}`. + The values are initialized from the uniform distribution :math:`\mathcal{U}(-{k}, {k})`, + where :math:`k = \frac{1}{\sqrt{D_1}}` and :math:`D_1` is ``input1_dims``. Args: input1_dims (int): The dimensionality of the input1 features @@ -119,8 +117,8 @@ class Bilinear(Module): ) def __call__(self, x1: mx.array, x2: mx.array) -> mx.array: - x1 = x1.reshape(x1.shape[0], 1, 1, x1.shape[1]) - x2 = x2.reshape(x2.shape[0], 1, x2.shape[1]) + x1 = mx.expand_dims(x1, axis=(-2, -3)) + x2 = mx.expand_dims(x2, axis=(-2)) y = mx.squeeze(x1 @ self.weight, -2).swapaxes(-1, -2) y = mx.squeeze(x2 @ y, -2) if "bias" in self: