feat: Add orthogonal initializer and corresponding tests (#1651)

* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan 2025-01-13 15:29:20 +00:00 committed by GitHub
parent 252e423e81
commit 5cc5201914
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 73 additions and 1 deletions

View File

@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
MLX was developed with contributions from the following individuals:
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.

View File

@ -395,3 +395,47 @@ def sparse(
return a
return initializer
def orthogonal(
gain: float = 1.0, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
r"""An initializer that returns an orthogonal matrix.
Args:
gain (float, optional): Scaling factor for the orthogonal matrix.
Default: ``1.0``.
dtype (Dtype, optional): Data type of the array. Default: ``float32``.
Returns:
Callable[[array], array]: An initializer that returns
an orthogonal matrix with the same shape as the input.
"""
def initializer(a: mx.array) -> mx.array:
if a.ndim != 2:
raise ValueError(
f"Orthogonal initialization requires a 2D array but got"
" a {a.ndim}D array."
)
rows, cols = a.shape
n = max(rows, cols)
rmat = mx.random.normal(shape=(n, n))
# Perform QR decomposition on CPU
q, r = mx.linalg.qr(rmat, stream=mx.cpu)
# Adjust the sign of Q using the diagonal of R
d = mx.diag(r)
q = q * mx.sign(d)
# Slice Q to the desired shape
q = q[:rows, :cols]
# Scale Q by gain
q = q * gain
return q.astype(dtype)
return initializer

View File

@ -106,6 +106,34 @@ class TestInit(mlx_tests.MLXTestCase):
with self.assertRaises(ValueError):
result = initializer(mx.zeros((1,)))
def test_orthogonal(self):
initializer = init.orthogonal(gain=1.0, dtype=mx.float32)
# Test with a square matrix
shape = (4, 4)
result = initializer(mx.zeros(shape, dtype=mx.float32))
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, mx.float32)
I = result @ result.T
eye = mx.eye(shape[0], dtype=mx.float32)
self.assertTrue(
mx.allclose(I, eye, atol=1e-5), "Orthogonal init failed on a square matrix."
)
# Test with a rectangular matrix: more rows than cols
shape = (6, 4)
result = initializer(mx.zeros(shape, dtype=mx.float32))
self.assertEqual(result.shape, shape)
self.assertEqual(result.dtype, mx.float32)
I = result.T @ result
eye = mx.eye(shape[1], dtype=mx.float32)
self.assertTrue(
mx.allclose(I, eye, atol=1e-5),
"Orthogonal init failed on a rectangular matrix.",
)
if __name__ == "__main__":
unittest.main()