feat: Add orthogonal initializer and corresponding tests (#1651)

* feat: Add orthogonal initializer and corresponding tests

* lint

* Add acknowledgements

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Nripesh Niketan
2025-01-13 15:29:20 +00:00
committed by GitHub
parent 252e423e81
commit 5cc5201914
3 changed files with 73 additions and 1 deletions

View File

@@ -395,3 +395,47 @@ def sparse(
return a
return initializer
def orthogonal(
    gain: float = 1.0, dtype: mx.Dtype = mx.float32
) -> Callable[[mx.array], mx.array]:
    r"""An initializer that returns an orthogonal matrix.

    A square random normal matrix of size ``max(rows, cols)`` is
    QR-decomposed, and the resulting ``Q`` factor is sliced down to the
    shape of the input array and scaled by ``gain``. For a non-square
    input the result is therefore a slice of a square orthogonal matrix.

    Args:
        gain (float, optional): Scaling factor for the orthogonal matrix.
            Default: ``1.0``.
        dtype (Dtype, optional): Data type of the array. Default: ``float32``.

    Returns:
        Callable[[array], array]: An initializer that returns
        an orthogonal matrix with the same shape as the input. The
        returned initializer raises ``ValueError`` if its input is not
        a 2D array.
    """

    def initializer(a: mx.array) -> mx.array:
        if a.ndim != 2:
            # The ``f`` prefix must be on the literal that contains the
            # placeholder; adjacent literals are concatenated but each
            # one is formatted (or not) independently.
            raise ValueError(
                "Orthogonal initialization requires a 2D array but got"
                f" a {a.ndim}D array."
            )

        rows, cols = a.shape
        # Work on a square matrix large enough to cover both dimensions.
        n = max(rows, cols)

        rmat = mx.random.normal(shape=(n, n))

        # Perform QR decomposition on CPU
        q, r = mx.linalg.qr(rmat, stream=mx.cpu)

        # Adjust the sign of Q using the diagonal of R: broadcasting
        # multiplies each column of Q by the sign of the matching
        # diagonal entry of R.
        d = mx.diag(r)
        q = q * mx.sign(d)

        # Slice Q to the desired shape
        q = q[:rows, :cols]

        # Scale Q by gain
        q = q * gain
        return q.astype(dtype)

    return initializer