mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-18 18:28:12 +08:00
feat: Add orthogonal initializer and corresponding tests (#1651)
* feat: Add orthogonal initializer and corresponding tests * lint * Add acknowledgements * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -395,3 +395,47 @@ def sparse(
|
||||
return a
|
||||
|
||||
return initializer
|
||||
|
||||
|
||||
def orthogonal(
|
||||
gain: float = 1.0, dtype: mx.Dtype = mx.float32
|
||||
) -> Callable[[mx.array], mx.array]:
|
||||
r"""An initializer that returns an orthogonal matrix.
|
||||
|
||||
Args:
|
||||
gain (float, optional): Scaling factor for the orthogonal matrix.
|
||||
Default: ``1.0``.
|
||||
dtype (Dtype, optional): Data type of the array. Default: ``float32``.
|
||||
|
||||
Returns:
|
||||
Callable[[array], array]: An initializer that returns
|
||||
an orthogonal matrix with the same shape as the input.
|
||||
"""
|
||||
|
||||
def initializer(a: mx.array) -> mx.array:
|
||||
if a.ndim != 2:
|
||||
raise ValueError(
|
||||
f"Orthogonal initialization requires a 2D array but got"
|
||||
" a {a.ndim}D array."
|
||||
)
|
||||
|
||||
rows, cols = a.shape
|
||||
n = max(rows, cols)
|
||||
|
||||
rmat = mx.random.normal(shape=(n, n))
|
||||
|
||||
# Perform QR decomposition on CPU
|
||||
q, r = mx.linalg.qr(rmat, stream=mx.cpu)
|
||||
|
||||
# Adjust the sign of Q using the diagonal of R
|
||||
d = mx.diag(r)
|
||||
q = q * mx.sign(d)
|
||||
|
||||
# Slice Q to the desired shape
|
||||
q = q[:rows, :cols]
|
||||
|
||||
# Scale Q by gain
|
||||
q = q * gain
|
||||
return q.astype(dtype)
|
||||
|
||||
return initializer
|
||||
|
Reference in New Issue
Block a user