G.astype(mx.bfloat16) to G.astype(G.dtype)

This commit is contained in:
Goekdeniz-Guelmez
2025-07-17 19:49:26 +02:00
parent 7f39e9c299
commit 060404d862

View File

@@ -896,7 +896,7 @@ class Muon(Optimizer):
def _zeropower_via_newtonschulz5(self, G, steps: int):
assert G.ndim >= 2
a, b, c = (3.4445, -4.7750, 2.0315)
X = G.astype(mx.bfloat16)
X = G.astype(G.dtype)
transpose_needed = G.shape[-2] > G.shape[-1]
if transpose_needed: