mirror of https://github.com/ml-explore/mlx.git
nits and adding it to test
@@ -933,13 +933,13 @@ class Muon(Optimizer):
             gradient = gradient + self.weight_decay * parameter
 
         # Update momentum buffer
-        v = self.momentum * state.get("v")
+        v = self.momentum * state["v"]
         v = v + (1 - self.momentum) * gradient
         state["v"] = v
 
         # Get effective gradient
         if self.nesterov:
-            effective_grad = gradient * self.momentum + v * (1 - self.momentum)
+            effective_grad = gradient * (1 - self.momentum) + v * self.momentum
         else:
             effective_grad = v
 
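Two fixes land in this hunk. Reading the momentum buffer with state["v"] instead of state.get("v") makes a missing buffer fail loudly: .get returns None for an absent key, and self.momentum * None only raises a confusing TypeError later, so the buffer is assumed to exist once the optimizer state is initialized. The Nesterov branch is also corrected so the raw gradient carries the (1 - momentum) weight and the buffer carries the momentum weight, i.e. a lerp from g toward v. A minimal standalone sketch of the corrected step, assuming MLX arrays; muon_momentum_step is an illustrative helper, not the MLX source:

import mlx.core as mx

def muon_momentum_step(gradient, v, momentum=0.95, nesterov=True):
    # EMA of gradients, as in the diff: v <- m * v + (1 - m) * g
    v = momentum * v + (1 - momentum) * gradient
    if nesterov:
        # Corrected blend: g gets the (1 - m) weight and v gets the m
        # weight, keeping the result a convex combination of g and v.
        effective_grad = gradient * (1 - momentum) + v * momentum
    else:
        effective_grad = v
    return effective_grad, v

g = mx.ones((4, 4))
v = mx.zeros((4, 4))
eff, v = muon_momentum_step(g, v, momentum=0.95)
print(eff[0, 0].item())  # ~0.0975, i.e. (1 - 0.95**2)

From a zero buffer, one step gives v = (1 - m) * g and an effective gradient of (1 - m**2) * g, which is what the printed value checks.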
@@ -963,7 +963,8 @@ class Muon(Optimizer):
             orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
 
         # Calculate scaling factor
-        scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
 
         return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
 
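The scale-factor change matters for weights with more than two dimensions. The reshape above suggests Muon flattens such weights to 2-D before Newton-Schulz orthogonalization, so the matrix dimensions the scale should reflect are those of the flattened effective gradient, not the parameter's trailing kernel axes. A small numeric illustration with hypothetical shapes (plain Python, no MLX needed):

# Illustrative only: how the two scale factors diverge for a
# conv-style weight that gets flattened before orthogonalization.
param_shape = (128, 3, 3, 3)                # (out, in, kh, kw)
grad2d_shape = (param_shape[0], 3 * 3 * 3)  # flattened to (128, 27)

old_scale = max(1, param_shape[-2] / param_shape[-1]) ** 0.5    # from (3, 3): 1.0
new_scale = max(1, grad2d_shape[-2] / grad2d_shape[-1]) ** 0.5  # from (128, 27): ~2.18

print(old_scale, new_scale)

For a plain 2-D weight the two expressions agree, so the change only affects parameters that were reshaped before orthogonalization.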