diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index 88820dbd3..465c49a20 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -902,11 +902,9 @@ class Muon(Optimizer): if transpose_needed: X = X.T - # Ensure spectral norm is at most 1 norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7) X = X / norm - # Perform the NS iterations for _ in range(steps): A = X @ X.T B = mx.addmm(b * A, A, A, beta=1.0, alpha=c) @@ -931,12 +929,10 @@ class Muon(Optimizer): else: effective_grad = v - # For tensors with fewer than 2 dimensions, skip Newton-Schulz if effective_grad.ndim < 2: orthogonalized_grad = effective_grad scale_factor = 1.0 else: - # Save original shape for 4D conv filters original_shape = effective_grad.shape reshape_needed = effective_grad.ndim > 2