Feature expand nn linear (#315)

* Added identity and bilinear layers
Added a reset_parameters option
Added normal init for bias

* pre-commit run

* add type hints for parameters and the return type
change Bilinear math to x_1 and x_2
change __call__ arguments to x and y instead of input and output
add explanation to the Initialization

* Remove unnecessary reshape

* Added 'i' to bilinear formula

* Changed bilinear computation to two matrix multiplications

* Avoided saving intermediate results; kept y in Bilinear for better clarity (it can be replaced with x1)

* Changed math formula in Linear
Added more explanation to math formulas
Changed x1, x2 reshape to support all inputs sizes
This commit is contained in:
Asaf Zorea
2024-01-02 16:08:53 +02:00
committed by GitHub
parent 44c1ce5e6a
commit 295ce9db09
3 changed files with 103 additions and 9 deletions

View File

@@ -12,12 +12,25 @@ from mlx.utils import tree_flatten, tree_map, tree_unflatten
class TestNN(mlx_tests.MLXTestCase):
def test_identity(self):
    """An Identity layer must pass its input through with the shape unchanged."""
    x = mx.zeros((10, 4))
    identity = nn.Identity()
    y = identity(x)
    # Output shape must exactly match the input shape.
    self.assertEqual(tuple(x.shape), tuple(y.shape))
def test_linear(self):
    """A Linear(4 -> 8) layer maps a (10, 4) batch to a (10, 8) batch."""
    x = mx.zeros((10, 4))
    linear = nn.Linear(input_dims=4, output_dims=8)
    y = linear(x)
    # Batch dimension is preserved; feature dimension becomes output_dims.
    self.assertEqual(tuple(y.shape), (10, 8))
def test_bilinear(self):
    """A Bilinear(2, 4 -> 6) layer combines (10, 2) and (10, 4) batches into (10, 6)."""
    x1 = mx.zeros((10, 2))
    x2 = mx.zeros((10, 4))
    bilinear = nn.Bilinear(input1_dims=2, input2_dims=4, output_dims=6)
    y = bilinear(x1, x2)
    # Both inputs share the batch dimension; output feature dim is output_dims.
    self.assertEqual(tuple(y.shape), (10, 6))
def test_cross_entropy(self):
logits = mx.array([[0.0, -float("inf")], [-float("inf"), 0.0]])
targets = mx.array([0, 1])