mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
G.astype(mx.bfloat16) to G.astype(G.dtype)
This commit is contained in:
@@ -896,7 +896,7 @@ class Muon(Optimizer):
|
|||||||
def _zeropower_via_newtonschulz5(self, G, steps: int):
|
def _zeropower_via_newtonschulz5(self, G, steps: int):
|
||||||
assert G.ndim >= 2
|
assert G.ndim >= 2
|
||||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
X = G.astype(mx.bfloat16)
|
X = G.astype(G.dtype)
|
||||||
transpose_needed = G.shape[-2] > G.shape[-1]
|
transpose_needed = G.shape[-2] > G.shape[-1]
|
||||||
|
|
||||||
if transpose_needed:
|
if transpose_needed:
|
||||||
|
|||||||
Reference in New Issue
Block a user