Mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)

Compare commits: simple-gem...d3d575cce7 (9 commits)
Commits:

- d3d575cce7
- 8f2744dcf3
- b12be4b7e0
- ebfcb4a14f
- 79175a1f35
- 59d4e4f61d
- 44f776921c
- 871ee2b9b0
- 6c048ab4da
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
 
 <a href="https://github.com/ml-explore/mlx/graphs/contributors">
 <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />

@@ -846,6 +846,126 @@ class Adafactor(Optimizer):
         return parameter - update
 
 
+class Muon(Optimizer):
+    r"""The Muon optimizer - MomentUm Orthogonalized by Newton-schulz.
+
+    Muon internally runs standard SGD-momentum, and then performs an orthogonalization post-
+    processing step, in which each 2D parameter's update is replaced with the nearest orthogonal
+    matrix. To efficiently orthogonalize each update, a Newton-Schulz iteration is used, which has
+    the advantage that it can be stably run in bfloat16 on the GPU.
+
+    For more details, see: https://kellerjordan.github.io/posts/muon/
+
+    Note:
+    - This optimizer may not be optimal for the embedding layer, the final fully connected layer,
+      or any 0D/1D parameters; those should be optimized by a standard method (e.g., AdamW).
+    - For 4D convolutional filters, it works by flattening their last dimensions.
+
+    Args:
+        learning_rate (float or callable): The learning rate used by the internal SGD.
+        momentum (float, optional): The momentum strength. Default: ``0.95``
+        weight_decay (float, optional): The weight decay (L2 penalty). Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for better performance.
+            Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for orthogonalization.
+            Default: ``5``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        momentum: float = 0.95,
+        weight_decay: float = 0.01,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.nesterov = nesterov
+        self.ns_steps = ns_steps
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def _zeropower_via_newtonschulz5(self, G, steps: int):
+        """
+        Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
+        quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
+        of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
+        zero even beyond the point where the iteration no longer converges all the way to one everywhere
+        on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
+        where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
+        performance at all relative to UV^T, where USV^T = G is the SVD.
+        """
+        assert G.ndim >= 2
+        a, b, c = (3.4445, -4.7750, 2.0315)
+        X = G.astype(mx.bfloat16)
+        transpose_needed = G.shape[-2] > G.shape[-1]
+
+        if transpose_needed:
+            X = X.T
+
+        # Ensure spectral norm is at most 1
+        norm = mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
+        X = X / norm
+
+        # Perform the NS iterations
+        for _ in range(steps):
+            A = X @ X.T
+            B = b * A + c * (A @ A)
+            X = a * X + B @ X
+
+        if transpose_needed:
+            X = X.T
+        return X
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the Muon parameter update"""
+
+        # Apply weight decay
+        if self.weight_decay != 0:
+            gradient = gradient + self.weight_decay * parameter
+
+        # Update momentum buffer
+        v = self.momentum * state.get("v")
+        v = v + (1 - self.momentum) * gradient
+        state["v"] = v
+
+        # Get effective gradient
+        if self.nesterov:
+            effective_grad = gradient * self.momentum + v * (1 - self.momentum)
+        else:
+            effective_grad = v
+
+        # For tensors with fewer than 2 dimensions, skip Newton-Schulz
+        if effective_grad.ndim < 2:
+            orthogonalized_grad = effective_grad
+            scale_factor = 1.0
+        else:
+            # Save original shape for 4D conv filters
+            original_shape = effective_grad.shape
+            reshape_needed = effective_grad.ndim > 2
+
+            if reshape_needed:
+                effective_grad = mx.reshape(effective_grad, (effective_grad.shape[0], -1))
+
+            # Apply Newton-Schulz orthogonalization
+            orthogonalized_grad = self._zeropower_via_newtonschulz5(effective_grad, steps=self.ns_steps)
+
+            # Reshape back if needed
+            if reshape_needed:
+                orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
+
+            # Calculate scaling factor
+            scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+
+        return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
+
+
 def clip_grad_norm(grads, max_norm):
     """Clips the global norm of the gradients.
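To make the orthogonalization step concrete, here is a self-contained illustration (not part of the diff) of the same quintic Newton-Schulz iteration applied to a random matrix, checking that the result is roughly orthogonal. It keeps float32 for readability instead of the bfloat16 used in `_zeropower_via_newtonschulz5`.

# Standalone illustration of the quintic Newton-Schulz orthogonalization Muon relies on.
# Mirrors _zeropower_via_newtonschulz5 above, but in float32 and outside the optimizer.
import mlx.core as mx

def ns_orthogonalize(G, steps=5):
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = G.T if G.shape[-2] > G.shape[-1] else G
    # Scale so the spectral norm is at most 1 (the Frobenius norm upper-bounds it).
    X = X / mx.sqrt(mx.sum(X * X, axis=(-2, -1), keepdims=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if G.shape[-2] > G.shape[-1] else X

G = mx.random.normal((16, 32))
O = ns_orthogonalize(G)
# O @ O.T should be near the identity; per the docstring the singular values land
# roughly in [0.5, 1.5], so expect approximate rather than exact orthogonality.
print(mx.max(mx.abs(O @ O.T - mx.eye(16))))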