Replace the manual multiply-add expressions in Muon's NS iterations with mx.addmm

This commit is contained in:
Goekdeniz-Guelmez
2025-07-17 19:57:18 +02:00
parent 4c0f7c713b
commit 698daee214

View File

@@ -909,8 +909,8 @@ class Muon(Optimizer):
# Perform the NS iterations # Perform the NS iterations
for _ in range(steps): for _ in range(steps):
A = X @ X.T A = X @ X.T
B = b * A + c * (A @ A) B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
X = a * X + B @ X X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
if transpose_needed: if transpose_needed:
X = X.T X = X.T