mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
match muon
This commit is contained in:
@@ -903,8 +903,7 @@ class Muon(Optimizer):
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
|
||||
X = X / norm
|
||||
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
|
||||
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
|
||||
Reference in New Issue
Block a user