mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
remove comments
This commit is contained in:
@@ -902,11 +902,9 @@ class Muon(Optimizer):
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
# Normalize by the Frobenius norm (an upper bound on the spectral norm) so the spectral norm is at most 1
|
||||
norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
|
||||
X = X / norm
|
||||
|
||||
# Perform the Newton-Schulz (NS) iterations
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
@@ -931,12 +929,10 @@ class Muon(Optimizer):
|
||||
else:
|
||||
effective_grad = v
|
||||
|
||||
# For tensors with fewer than 2 dimensions, skip Newton-Schulz
|
||||
if effective_grad.ndim < 2:
|
||||
orthogonalized_grad = effective_grad
|
||||
scale_factor = 1.0
|
||||
else:
|
||||
# Save original shape; any tensor with more than 2 dimensions (e.g. 4D conv filters) is flagged for reshaping
|
||||
original_shape = effective_grad.shape
|
||||
reshape_needed = effective_grad.ndim > 2
|
||||
|
||||
|
Reference in New Issue
Block a user