From a9512835611dbaee74f567d9c6d2add34c0eb306 Mon Sep 17 00:00:00 2001 From: zorea Date: Sun, 31 Dec 2023 20:50:41 +0200 Subject: [PATCH] Added 'i' to bilinear formula --- python/mlx/nn/layers/linear.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/python/mlx/nn/layers/linear.py b/python/mlx/nn/layers/linear.py index 3aa9adb1b..d8267360a 100644 --- a/python/mlx/nn/layers/linear.py +++ b/python/mlx/nn/layers/linear.py @@ -77,11 +77,12 @@ class Bilinear(Module): .. math:: - y = x_1^\top W x_2 + b + y_i = x_1^\top W_i x_2 + b_i where :math:`W` has shape ``[output_dims, input1_dims, input2_dims]``. :math:`b` has shape ``[output_dims, ]``. + :math:`i` is the index for output dimensions. The values are initialized from :math:`\mathcal{U}(-{k}, {k})`, where :math:`k = \frac{1}{\sqrt{input1\_dims}}`.