mirror of https://github.com/ml-explore/mlx.git
remove comments
@@ -919,16 +919,13 @@ class Muon(Optimizer):
     def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
         """Performs the Muon parameter update"""
 
-        # Apply weight decay
         if self.weight_decay != 0:
             gradient = gradient + self.weight_decay * parameter
 
-        # Update momentum buffer
         v = self.momentum * state["v"]
         v = v + (1 - self.momentum) * gradient
         state["v"] = v
 
-        # Get effective gradient
         if self.nesterov:
             effective_grad = gradient * (1 - self.momentum) + v * self.momentum
         else:
@@ -948,17 +945,13 @@ class Muon(Optimizer):
             effective_grad, (effective_grad.shape[0], -1)
         )
 
-        # Apply Newton-Schulz orthogonalization
         orthogonalized_grad = self._zeropower_via_newtonschulz5(
             effective_grad, steps=self.ns_steps
         )
 
-        # Reshape back if needed
         if reshape_needed:
             orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
 
-        # Calculate scaling factor
-        # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
         scale_factor = (
             max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
         )