nits and adding it to test

This commit is contained in:
Goekdeniz-Guelmez
2025-07-16 19:13:40 +02:00
parent 650c956fe6
commit df6d9e972f
2 changed files with 51 additions and 3 deletions

View File

@@ -933,13 +933,13 @@ class Muon(Optimizer):
gradient = gradient + self.weight_decay * parameter
# Update momentum buffer
v = self.momentum * state.get("v")
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
# Get effective gradient
if self.nesterov:
effective_grad = gradient * self.momentum + v * (1 - self.momentum)
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
else:
effective_grad = v
@@ -963,7 +963,8 @@ class Muon(Optimizer):
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
# Calculate scaling factor
scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor