mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 01:17:26 +08:00
feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
252e423e81
commit
5cc5201914
@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
|||||||
|
|
||||||
MLX was developed with contributions from the following individuals:
|
MLX was developed with contributions from the following individuals:
|
||||||
|
|
||||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
|
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
|
||||||
- Juarez Bochi: Fixed bug in cross attention.
|
- Juarez Bochi: Fixed bug in cross attention.
|
||||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||||
|
@ -395,3 +395,47 @@ def sparse(
|
|||||||
return a
|
return a
|
||||||
|
|
||||||
return initializer
|
return initializer
|
||||||
|
|
||||||
|
|
||||||
|
def orthogonal(
    gain: float = 1.0, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns an orthogonal matrix.

    The matrix is obtained from the QR decomposition of a random Gaussian
    matrix; the sign of the diagonal of ``R`` is folded back into ``Q`` to
    remove the sign ambiguity of the decomposition.

    Args:
        gain (float, optional): Scaling factor for the orthogonal matrix.
            Default: ``1.0``.
        dtype (Dtype, optional): Data type of the array. Default: ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns
        an orthogonal matrix with the same shape as the input.

    Raises:
        ValueError: If the input array is not 2-D.
    """

    def initializer(a: mx.array) -> mx.array:
        if a.ndim != 2:
            # BUGFIX: the f-prefix was previously on the first literal while
            # the {a.ndim} placeholder lived in the second, so the message
            # printed the literal text "{a.ndim}" instead of the rank.
            raise ValueError(
                "Orthogonal initialization requires a 2D array but got"
                f" a {a.ndim}D array."
            )

        rows, cols = a.shape
        # Decompose a square matrix of the larger dimension, then slice.
        n = max(rows, cols)

        rmat = mx.random.normal(shape=(n, n))

        # Perform QR decomposition on CPU (QR is only implemented there).
        q, r = mx.linalg.qr(rmat, stream=mx.cpu)

        # Adjust the sign of Q using the diagonal of R so the sign choice
        # made by the QR routine does not bias the result.
        d = mx.diag(r)
        q = q * mx.sign(d)

        # Slice Q to the desired shape
        q = q[:rows, :cols]

        # Scale Q by gain
        q = q * gain
        return q.astype(dtype)

    return initializer
|
||||||
|
@ -106,6 +106,34 @@ class TestInit(mlx_tests.MLXTestCase):
|
|||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
result = initializer(mx.zeros((1,)))
|
result = initializer(mx.zeros((1,)))
|
||||||
|
|
||||||
|
def test_orthogonal(self):
    initializer = init.orthogonal(gain=1.0, dtype=mx.float32)

    # Square case: the full matrix should satisfy Q @ Q^T == I.
    square = (4, 4)
    mat = initializer(mx.zeros(square, dtype=mx.float32))
    self.assertEqual(mat.shape, square)
    self.assertEqual(mat.dtype, mx.float32)
    self.assertTrue(
        mx.allclose(
            mat @ mat.T,
            mx.eye(square[0], dtype=mx.float32),
            atol=1e-5,
        ),
        "Orthogonal init failed on a square matrix.",
    )

    # Tall case (more rows than cols): columns are orthonormal,
    # so Q^T @ Q == I.
    tall = (6, 4)
    mat = initializer(mx.zeros(tall, dtype=mx.float32))
    self.assertEqual(mat.shape, tall)
    self.assertEqual(mat.dtype, mx.float32)
    self.assertTrue(
        mx.allclose(
            mat.T @ mat,
            mx.eye(tall[1], dtype=mx.float32),
            atol=1e-5,
        ),
        "Orthogonal init failed on a rectangular matrix.",
    )
|
|
||||||
|
|
||||||
# Allow running this test file directly (e.g. `python test_init.py`).
if __name__ == "__main__":
    unittest.main()
||||||
|
Loading…
Reference in New Issue
Block a user