diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index 71d2ab9ba..07b68cc5b 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -903,8 +903,7 @@ class Muon(Optimizer): if transpose_needed: X = X.T - norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7) - X = X / norm + X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7) for _ in range(steps): A = X @ X.T