mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
22 Commits
v0.29.2
...
0a8bb904d7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a8bb904d7 | ||
|
|
c535d8c1b5 | ||
|
|
4b3d7634cd | ||
|
|
516d172ba5 | ||
|
|
698daee214 | ||
|
|
4c0f7c713b | ||
|
|
3889c805da | ||
|
|
060404d862 | ||
|
|
7f39e9c299 | ||
|
|
baad6e392b | ||
|
|
784e0716fe | ||
|
|
df6d9e972f | ||
|
|
650c956fe6 | ||
|
|
d3d575cce7 | ||
|
|
8f2744dcf3 | ||
|
|
b12be4b7e0 | ||
|
|
ebfcb4a14f | ||
|
|
79175a1f35 | ||
|
|
59d4e4f61d | ||
|
|
44f776921c | ||
|
|
871ee2b9b0 | ||
|
|
6c048ab4da |
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
||||
@@ -19,3 +19,4 @@ Common Optimizers
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
Muon
|
||||
|
||||
@@ -848,6 +848,107 @@ class Adafactor(Optimizer):
|
||||
return parameter - update
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
r"""The Muon optimizer.
|
||||
|
||||
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
|
||||
original implementation: `Muon: An optimizer for hidden layers in neural
|
||||
networks <https://kellerjordan.github.io/posts/muon/>`_
|
||||
|
||||
Note:
|
||||
- Muon may be sub-optimal for the embedding layer, the final fully
|
||||
connected layer, or any 0D/1D parameters. Those should be optimized
|
||||
by a different method (e.g., :class:`AdamW`).
|
||||
- For 4D convolutional filters, it works by flattening their last
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
learning_rate (float or callable): The learning rate.
|
||||
momentum (float, optional): The momentum strength. Default: ``0.95``
|
||||
weight_decay (float, optional): The weight decay (L2 penalty).
|
||||
Default: ``0.01``
|
||||
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
|
||||
better performance. Default: ``True``
|
||||
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
|
||||
orthogonalization. Default: ``5``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.01,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.nesterov = nesterov
|
||||
self.ns_steps = ns_steps
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||
assert (
|
||||
X.ndim == 2
|
||||
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
|
||||
X = X / norm
|
||||
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
return X
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Muon parameter update"""
|
||||
|
||||
if self.weight_decay != 0:
|
||||
gradient = gradient + self.weight_decay * parameter
|
||||
|
||||
v = self.momentum * state["v"]
|
||||
v = v + (1 - self.momentum) * gradient
|
||||
state["v"] = v
|
||||
|
||||
if self.nesterov:
|
||||
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||
else:
|
||||
update = v
|
||||
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
|
||||
if update.ndim >= 2:
|
||||
original_shape = update.shape
|
||||
reshape_needed = update.ndim > 2
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, (update.shape[0], -1))
|
||||
|
||||
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, original_shape)
|
||||
|
||||
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||
|
||||
return parameter - lr * update
|
||||
|
||||
|
||||
def clip_grad_norm(grads, max_norm):
|
||||
"""Clips the global norm of the gradients.
|
||||
|
||||
|
||||
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(xp["x"].shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
|
||||
def test_muon(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
|
||||
"second": mx.zeros((3, 3)),
|
||||
"conv": mx.zeros((16, 8, 3, 3)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
# Test update
|
||||
updated_params = optim.apply_gradients(grads, params)
|
||||
|
||||
# Check that shapes are preserved
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, u: p.shape == u.shape,
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Check that parameters actually changed
|
||||
self.assertFalse(
|
||||
tree_equal(
|
||||
lambda p, u: mx.array_equal(p, u),
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Test with different configurations
|
||||
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||
optim_no_nesterov.apply_gradients(grads, params)
|
||||
|
||||
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||
optim_no_momentum.apply_gradients(grads, params)
|
||||
|
||||
def test_compiled_optimizer(self):
|
||||
model = nn.Linear(10, 10)
|
||||
x = mx.random.uniform(shape=(2, 10))
|
||||
|
||||
Reference in New Issue
Block a user