mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-11 22:44:38 +08:00
remove comments
This commit is contained in:
@@ -902,11 +902,9 @@ class Muon(Optimizer):
|
|||||||
if transpose_needed:
|
if transpose_needed:
|
||||||
X = X.T
|
X = X.T
|
||||||
|
|
||||||
# Ensure spectral norm is at most 1
|
|
||||||
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
|
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
|
||||||
X = X / norm
|
X = X / norm
|
||||||
|
|
||||||
# Perform the NS iterations
|
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
A = X @ X.T
|
A = X @ X.T
|
||||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||||
@@ -931,12 +929,10 @@ class Muon(Optimizer):
|
|||||||
else:
|
else:
|
||||||
effective_grad = v
|
effective_grad = v
|
||||||
|
|
||||||
# For tensors with fewer than 2 dimensions, skip Newton-Schulz
|
|
||||||
if effective_grad.ndim < 2:
|
if effective_grad.ndim < 2:
|
||||||
orthogonalized_grad = effective_grad
|
orthogonalized_grad = effective_grad
|
||||||
scale_factor = 1.0
|
scale_factor = 1.0
|
||||||
else:
|
else:
|
||||||
# Save original shape for 4D conv filters
|
|
||||||
original_shape = effective_grad.shape
|
original_shape = effective_grad.shape
|
||||||
reshape_needed = effective_grad.ndim > 2
|
reshape_needed = effective_grad.ndim > 2
|
||||||
|
|
||||||
|
Reference in New Issue
Block a user