mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 00:04:41 +08:00
feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests * lint * Add acknowledgements * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -106,6 +106,34 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
result = initializer(mx.zeros((1,)))
|
||||
|
||||
def test_orthogonal(self):
|
||||
initializer = init.orthogonal(gain=1.0, dtype=mx.float32)
|
||||
|
||||
# Test with a square matrix
|
||||
shape = (4, 4)
|
||||
result = initializer(mx.zeros(shape, dtype=mx.float32))
|
||||
self.assertEqual(result.shape, shape)
|
||||
self.assertEqual(result.dtype, mx.float32)
|
||||
|
||||
I = result @ result.T
|
||||
eye = mx.eye(shape[0], dtype=mx.float32)
|
||||
self.assertTrue(
|
||||
mx.allclose(I, eye, atol=1e-5), "Orthogonal init failed on a square matrix."
|
||||
)
|
||||
|
||||
# Test with a rectangular matrix: more rows than cols
|
||||
shape = (6, 4)
|
||||
result = initializer(mx.zeros(shape, dtype=mx.float32))
|
||||
self.assertEqual(result.shape, shape)
|
||||
self.assertEqual(result.dtype, mx.float32)
|
||||
|
||||
I = result.T @ result
|
||||
eye = mx.eye(shape[1], dtype=mx.float32)
|
||||
self.assertTrue(
|
||||
mx.allclose(I, eye, atol=1e-5),
|
||||
"Orthogonal init failed on a rectangular matrix.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Reference in New Issue
Block a user