Muon: evaluate the Newton-Schulz polynomial with mx.addmm.

Inside the NS loop the two update expressions are rewritten as addmm calls,
using the documented form addmm(c, a, b, alpha, beta) = alpha * (a @ b) + beta * c:

    B = b * A + c * (A @ A)   ->   mx.addmm(b * A, A, A, beta=1.0, alpha=c)
    X = a * X + B @ X         ->   mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)

The arithmetic being computed is the same before and after.

NOTE(review): line breaks and indentation were restored from a
whitespace-collapsed copy of this patch. The layout follows the hunk header
(8 old / 8 new lines); the 8/12-space indentation assumes a method body
inside the class -- confirm against the target file before applying.

diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 77faffcc9..88820dbd3 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -909,8 +909,8 @@ class Muon(Optimizer):
         # Perform the NS iterations
         for _ in range(steps):
             A = X @ X.T
-            B = b * A + c * (A @ A)
-            X = a * X + B @ X
+            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
+            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
 
         if transpose_needed:
             X = X.T