remove coments

This commit is contained in:
Goekdeniz-Guelmez
2025-07-17 19:53:56 +02:00
parent 3889c805da
commit 4c0f7c713b

View File

@@ -919,16 +919,13 @@ class Muon(Optimizer):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
# Apply weight decay
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
# Update momentum buffer
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
# Get effective gradient
if self.nesterov:
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
else:
@@ -948,17 +945,13 @@ class Muon(Optimizer):
effective_grad, (effective_grad.shape[0], -1)
)
# Apply Newton-Schulz orthogonalization
orthogonalized_grad = self._zeropower_via_newtonschulz5(
effective_grad, steps=self.ns_steps
)
# Reshape back if needed
if reshape_needed:
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
# Calculate scaling factor
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
scale_factor = (
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
)