match muon

This commit is contained in:
Awni Hannun
2025-07-18 06:43:11 -07:00
parent 0a8bb904d7
commit 508bd25e29

View File

@@ -903,8 +903,7 @@ class Muon(Optimizer):
         if transpose_needed:
             X = X.T
-        norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
-        X = X / norm
+        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
         for _ in range(steps):
             A = X @ X.T