mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
replace with mx.addmm
This commit is contained in:
@@ -909,8 +909,8 @@ class Muon(Optimizer):
|
|||||||
# Perform the NS iterations
|
# Perform the NS iterations
|
||||||
for _ in range(steps):
|
for _ in range(steps):
|
||||||
A = X @ X.T
|
A = X @ X.T
|
||||||
B = b * A + c * (A @ A)
|
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||||
X = a * X + B @ X
|
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||||
|
|
||||||
if transpose_needed:
|
if transpose_needed:
|
||||||
X = X.T
|
X = X.T
|
||||||
|
|||||||
Reference in New Issue
Block a user