remove coments

This commit is contained in:
Goekdeniz-Guelmez
2025-07-17 19:53:56 +02:00
parent 3889c805da
commit 4c0f7c713b

View File

@@ -919,16 +919,13 @@ class Muon(Optimizer):
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict): def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update""" """Performs the Muon parameter update"""
# Apply weight decay
if self.weight_decay != 0: if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter gradient = gradient + self.weight_decay * parameter
# Update momentum buffer
v = self.momentum * state["v"] v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient v = v + (1 - self.momentum) * gradient
state["v"] = v state["v"] = v
# Get effective gradient
if self.nesterov: if self.nesterov:
effective_grad = gradient * (1 - self.momentum) + v * self.momentum effective_grad = gradient * (1 - self.momentum) + v * self.momentum
else: else:
@@ -948,17 +945,13 @@ class Muon(Optimizer):
effective_grad, (effective_grad.shape[0], -1) effective_grad, (effective_grad.shape[0], -1)
) )
# Apply Newton-Schulz orthogonalization
orthogonalized_grad = self._zeropower_via_newtonschulz5( orthogonalized_grad = self._zeropower_via_newtonschulz5(
effective_grad, steps=self.ns_steps effective_grad, steps=self.ns_steps
) )
# Reshape back if needed
if reshape_needed: if reshape_needed:
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape) orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
# Calculate scaling factor
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
scale_factor = ( scale_factor = (
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5 max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
) )