diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py index 3a2bb818f..895ebec62 100644 --- a/python/mlx/optimizers/optimizers.py +++ b/python/mlx/optimizers/optimizers.py @@ -896,7 +896,7 @@ class Muon(Optimizer): def _zeropower_via_newtonschulz5(self, G, steps: int): assert G.ndim >= 2 a, b, c = (3.4445, -4.7750, 2.0315) - X = G.astype(mx.bfloat16) + X = G.astype(G.dtype) transpose_needed = G.shape[-2] > G.shape[-1] if transpose_needed: