mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
match muon
This commit is contained in:
@@ -903,8 +903,7 @@ class Muon(Optimizer):
 if transpose_needed:
     X = X.T

-norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
-X = X / norm
+X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)

 for _ in range(steps):
     A = X @ X.T
Reference in New Issue
Block a user