feat: Add orthogonal initializer and corresponding tests (#1651)

* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2025-01-13 15:29:20 +00:00
committed by GitHub
parent 252e423e81
commit 5cc5201914
3 changed files with 73 additions and 1 deletions

View File

@@ -106,6 +106,34 @@ class TestInit(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
result = initializer(mx.zeros((1,)))
def test_orthogonal(self):
initializer = init.orthogonal(gain=1.0, dtype=mx.float32)
# Test with a square matrix
shape = (4, 4)
result = initializer(mx.zeros(shape, dtype=mx.float32))
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, mx.float32)
I = result @ result.T
eye = mx.eye(shape[0], dtype=mx.float32)
self.assertTrue(
mx.allclose(I, eye, atol=1e-5), "Orthogonal init failed on a square matrix."
)
# Test with a rectangular matrix: more rows than cols
shape = (6, 4)
result = initializer(mx.zeros(shape, dtype=mx.float32))
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, mx.float32)
I = result.T @ result
eye = mx.eye(shape[1], dtype=mx.float32)
self.assertTrue(
mx.allclose(I, eye, atol=1e-5),
"Orthogonal init failed on a rectangular matrix.",
)
if __name__ == "__main__":
unittest.main()