mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 00:52:41 +08:00
feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests * lint * Add acknowledgements * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
252e423e81
commit
5cc5201914
@ -7,7 +7,7 @@ with a short description of your contribution(s) below. For example:
|
||||
|
||||
MLX was developed with contributions from the following individuals:
|
||||
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`.
|
||||
- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
|
||||
- Juarez Bochi: Fixed bug in cross attention.
|
||||
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
|
||||
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
|
||||
|
@ -395,3 +395,47 @@ def sparse(
|
||||
return a
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def orthogonal(
|
||||
gain: float = 1.0, dtype: mx.Dtype = mx.float32
|
||||
) -> Callable[[mx.array], mx.array]:
|
||||
r"""An initializer that returns an orthogonal matrix.
|
||||
|
||||
Args:
|
||||
gain (float, optional): Scaling factor for the orthogonal matrix.
|
||||
Default: ``1.0``.
|
||||
dtype (Dtype, optional): Data type of the array. Default: ``float32``.
|
||||
|
||||
Returns:
|
||||
Callable[[array], array]: An initializer that returns
|
||||
an orthogonal matrix with the same shape as the input.
|
||||
"""
|
||||
|
||||
def initializer(a: mx.array) -> mx.array:
|
||||
if a.ndim != 2:
|
||||
raise ValueError(
|
||||
f"Orthogonal initialization requires a 2D array but got"
|
||||
" a {a.ndim}D array."
|
||||
)
|
||||
|
||||
rows, cols = a.shape
|
||||
n = max(rows, cols)
|
||||
|
||||
rmat = mx.random.normal(shape=(n, n))
|
||||
|
||||
# Perform QR decomposition on CPU
|
||||
q, r = mx.linalg.qr(rmat, stream=mx.cpu)
|
||||
|
||||
# Adjust the sign of Q using the diagonal of R
|
||||
d = mx.diag(r)
|
||||
q = q * mx.sign(d)
|
||||
|
||||
# Slice Q to the desired shape
|
||||
q = q[:rows, :cols]
|
||||
|
||||
# Scale Q by gain
|
||||
q = q * gain
|
||||
return q.astype(dtype)
|
||||
|
||||
return initializer
|
||||
|
@ -106,6 +106,34 @@ class TestInit(mlx_tests.MLXTestCase):
|
||||
with self.assertRaises(ValueError):
|
||||
result = initializer(mx.zeros((1,)))
|
||||
|
||||
def test_orthogonal(self):
|
||||
initializer = init.orthogonal(gain=1.0, dtype=mx.float32)
|
||||
|
||||
# Test with a square matrix
|
||||
shape = (4, 4)
|
||||
result = initializer(mx.zeros(shape, dtype=mx.float32))
|
||||
self.assertEqual(result.shape, shape)
|
||||
self.assertEqual(result.dtype, mx.float32)
|
||||
|
||||
I = result @ result.T
|
||||
eye = mx.eye(shape[0], dtype=mx.float32)
|
||||
self.assertTrue(
|
||||
mx.allclose(I, eye, atol=1e-5), "Orthogonal init failed on a square matrix."
|
||||
)
|
||||
|
||||
# Test with a rectangular matrix: more rows than cols
|
||||
shape = (6, 4)
|
||||
result = initializer(mx.zeros(shape, dtype=mx.float32))
|
||||
self.assertEqual(result.shape, shape)
|
||||
self.assertEqual(result.dtype, mx.float32)
|
||||
|
||||
I = result.T @ result
|
||||
eye = mx.eye(shape[1], dtype=mx.float32)
|
||||
self.assertTrue(
|
||||
mx.allclose(I, eye, atol=1e-5),
|
||||
"Orthogonal init failed on a rectangular matrix.",
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
Loading…
Reference in New Issue
Block a user