mirror of
https://github.com/ml-explore/mlx.git
synced 2025-09-01 21:04:41 +08:00
G.astype(mx.bfloat16) to G.astype(G.dtype)
This commit is contained in:
@@ -896,7 +896,7 @@ class Muon(Optimizer):
|
||||
def _zeropower_via_newtonschulz5(self, G, steps: int):
|
||||
assert G.ndim >= 2
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
X = G.astype(mx.bfloat16)
|
||||
X = G.astype(G.dtype)
|
||||
transpose_needed = G.shape[-2] > G.shape[-1]
|
||||
|
||||
if transpose_needed:
|
||||
|
Reference in New Issue
Block a user