mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
replace with mx.addmm
This commit is contained in:
@@ -909,8 +909,8 @@ class Muon(Optimizer):
|
||||
# Perform the NS iterations
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = b * A + c * (A @ A)
|
||||
X = a * X + B @ X
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
Reference in New Issue
Block a user