mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
remove coments
This commit is contained in:
@@ -919,16 +919,13 @@ class Muon(Optimizer):
|
|||||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||||
"""Performs the Muon parameter update"""
|
"""Performs the Muon parameter update"""
|
||||||
|
|
||||||
# Apply weight decay
|
|
||||||
if self.weight_decay != 0:
|
if self.weight_decay != 0:
|
||||||
gradient = gradient + self.weight_decay * parameter
|
gradient = gradient + self.weight_decay * parameter
|
||||||
|
|
||||||
# Update momentum buffer
|
|
||||||
v = self.momentum * state["v"]
|
v = self.momentum * state["v"]
|
||||||
v = v + (1 - self.momentum) * gradient
|
v = v + (1 - self.momentum) * gradient
|
||||||
state["v"] = v
|
state["v"] = v
|
||||||
|
|
||||||
# Get effective gradient
|
|
||||||
if self.nesterov:
|
if self.nesterov:
|
||||||
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
|
effective_grad = gradient * (1 - self.momentum) + v * self.momentum
|
||||||
else:
|
else:
|
||||||
@@ -948,17 +945,13 @@ class Muon(Optimizer):
|
|||||||
effective_grad, (effective_grad.shape[0], -1)
|
effective_grad, (effective_grad.shape[0], -1)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Apply Newton-Schulz orthogonalization
|
|
||||||
orthogonalized_grad = self._zeropower_via_newtonschulz5(
|
orthogonalized_grad = self._zeropower_via_newtonschulz5(
|
||||||
effective_grad, steps=self.ns_steps
|
effective_grad, steps=self.ns_steps
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reshape back if needed
|
|
||||||
if reshape_needed:
|
if reshape_needed:
|
||||||
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
|
orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
|
||||||
|
|
||||||
# Calculate scaling factor
|
|
||||||
# scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
|
|
||||||
scale_factor = (
|
scale_factor = (
|
||||||
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
|
max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
|
||||||
)
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user