diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 3aa9adb1b..d8267360a 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -77,11 +77,12 @@ class Bilinear(Module): .. math:: - y = x_1^\top W x_2 + b + y_i = x_1^\top W_i x_2 + b_i where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``. :math:`b` has shape ``[output_dims, ]``. + :math:`i` is the index for output dimensions. The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{input1\_dims}}`.