mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Compare commits
22 Commits
simple-gem
...
0a8bb904d7
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a8bb904d7 | ||
|
|
c535d8c1b5 | ||
|
|
4b3d7634cd | ||
|
|
516d172ba5 | ||
|
|
698daee214 | ||
|
|
4c0f7c713b | ||
|
|
3889c805da | ||
|
|
060404d862 | ||
|
|
7f39e9c299 | ||
|
|
baad6e392b | ||
|
|
784e0716fe | ||
|
|
df6d9e972f | ||
|
|
650c956fe6 | ||
|
|
d3d575cce7 | ||
|
|
8f2744dcf3 | ||
|
|
b12be4b7e0 | ||
|
|
ebfcb4a14f | ||
|
|
79175a1f35 | ||
|
|
59d4e4f61d | ||
|
|
44f776921c | ||
|
|
871ee2b9b0 | ||
|
|
6c048ab4da |
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||||
|
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-Schulz)` optimizer.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
|||||||
@@ -19,3 +19,4 @@ Common Optimizers
|
|||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
MultiOptimizer
|
MultiOptimizer
|
||||||
|
Muon
|
||||||
|
|||||||
@@ -848,6 +848,107 @@ class Adafactor(Optimizer):
|
|||||||
return parameter - update
|
return parameter - update
|
||||||
|
|
||||||
|
|
||||||
|
class Muon(Optimizer):
    r"""The Muon optimizer.

    Our Muon (MomentUm Orthogonalized by Newton-Schulz) optimizer follows the
    original implementation: `Muon: An optimizer for hidden layers in neural
    networks <https://kellerjordan.github.io/posts/muon/>`_

    Note:
        - Muon may be sub-optimal for the embedding layer, the final fully
          connected layer, or any 0D/1D parameters. Those should be optimized
          by a different method (e.g., :class:`AdamW`).
        - Parameters with more than two dimensions (e.g. 4D convolutional
          filters) are handled by flattening every dimension after the first
          before orthogonalization and restoring the shape afterwards.
        - Parameters with fewer than two dimensions are not orthogonalized;
          they receive the plain (optionally Nesterov) momentum update.

    Args:
        learning_rate (float or callable): The learning rate.
        momentum (float, optional): The momentum strength. Default: ``0.95``
        weight_decay (float, optional): The weight decay (L2 penalty), added
            to the gradient as ``weight_decay * parameter``.
            Default: ``0.01``
        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
            better performance. Default: ``True``
        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
            orthogonalization. Default: ``5``
    """

    def __init__(
        self,
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
        momentum: float = 0.95,
        weight_decay: float = 0.01,
        nesterov: bool = True,
        ns_steps: int = 5,
    ):
        super().__init__()

        self._maybe_schedule("learning_rate", learning_rate)
        self.momentum = momentum
        self.weight_decay = weight_decay
        self.nesterov = nesterov
        self.ns_steps = ns_steps

    def init_single(self, parameter: mx.array, state: dict):
        """Initialize optimizer state"""
        # "v" is the momentum buffer; it starts at zero with the parameter's
        # shape and dtype.
        state["v"] = mx.zeros_like(parameter)

    def _zeropower_via_newtonschulz5(self, X, steps: int):
        """Run ``steps`` quintic Newton-Schulz iterations on a 2D array.

        Args:
            X (mx.array): The 2D array to orthogonalize.
            steps (int): Number of iterations to run.

        Returns:
            mx.array: An array with the same shape as the input ``X``.
        """
        assert (
            X.ndim == 2
        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
        # Coefficients of the quintic polynomial applied at every iteration.
        a, b, c = (3.4445, -4.7750, 2.0315)

        # Work on the orientation with rows <= cols so that ``X @ X.T`` is the
        # smaller of the two possible Gram matrices.
        transpose_needed = X.shape[-2] > X.shape[-1]
        if transpose_needed:
            X = X.T

        # Normalize by the root of the sum of squares of all entries; the
        # 1e-7 term keeps the division finite for an all-zero input.
        norm = mx.sqrt(mx.sum(mx.square(X), keepdims=True) + 1e-7)
        X = X / norm

        for _ in range(steps):
            A = X @ X.T
            # B = b * A + c * (A @ A)
            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
            # X = a * X + B @ X
            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)

        # Undo the transpose so the caller gets back the input's shape.
        if transpose_needed:
            X = X.T
        return X

    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
        """Performs the Muon parameter update"""
        # L2-style weight decay: fold the penalty into the gradient.
        if self.weight_decay != 0:
            gradient = gradient + self.weight_decay * parameter

        # Exponential moving average of the (decayed) gradient.
        v = self.momentum * state["v"]
        v = v + (1 - self.momentum) * gradient
        state["v"] = v

        if self.nesterov:
            # Blend the current gradient with the freshly updated buffer.
            update = gradient * (1 - self.momentum) + v * self.momentum
        else:
            update = v

        lr = self.learning_rate.astype(gradient.dtype)

        # Only matrix-like parameters are orthogonalized.
        if update.ndim >= 2:
            original_shape = update.shape
            reshape_needed = update.ndim > 2

            # Collapse everything after the first axis into one axis so the
            # Newton-Schulz routine sees a 2D array.
            if reshape_needed:
                update = mx.reshape(update, (update.shape[0], -1))

            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)

            if reshape_needed:
                update = mx.reshape(update, original_shape)

            # Scale the step by sqrt(max(1, rows / cols)), where rows and cols
            # are the last two axes of the update in its restored shape.
            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5

        return parameter - lr * update
|
||||||
|
|
||||||
|
|
||||||
def clip_grad_norm(grads, max_norm):
|
def clip_grad_norm(grads, max_norm):
|
||||||
"""Clips the global norm of the gradients.
|
"""Clips the global norm of the gradients.
|
||||||
|
|
||||||
|
|||||||
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(xp["x"].shape, x.shape)
|
self.assertEqual(xp["x"].shape, x.shape)
|
||||||
self.assertEqual(optimizer.state["step"], 2)
|
self.assertEqual(optimizer.state["step"], 2)
|
||||||
|
|
||||||
|
def test_muon(self):
    """Exercise ``opt.Muon`` on 1D, 2D and 4D parameters.

    Checks that explicit initialization creates a zero momentum buffer for
    every parameter, that one update keeps shapes while changing values, and
    that the non-Nesterov and zero-momentum configurations also run and keep
    shapes.
    """
    params = {
        "first": [mx.zeros((10, 5)), mx.zeros((1,))],
        "second": mx.zeros((3, 3)),
        "conv": mx.zeros((16, 8, 3, 3)),
    }
    grads = tree_map(lambda x: mx.ones_like(x), params)

    # Explicit init
    optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
    optim.init(params)
    self.assertTrue(
        tree_equal(
            lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
            params,
            optim.state,
        )
    )

    # Test update
    updated_params = optim.apply_gradients(grads, params)

    # Check that shapes are preserved
    self.assertTrue(
        tree_equal(
            lambda p, u: p.shape == u.shape,
            params,
            updated_params,
        )
    )

    # Check that parameters actually changed
    self.assertFalse(
        tree_equal(
            lambda p, u: mx.array_equal(p, u),
            params,
            updated_params,
        )
    )

    # Test with different configurations. Each one relies on lazy state
    # initialization inside ``apply_gradients`` and must keep shapes intact.
    optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
    no_nesterov_params = optim_no_nesterov.apply_gradients(grads, params)
    self.assertTrue(
        tree_equal(
            lambda p, u: p.shape == u.shape,
            params,
            no_nesterov_params,
        )
    )

    optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
    no_momentum_params = optim_no_momentum.apply_gradients(grads, params)
    self.assertTrue(
        tree_equal(
            lambda p, u: p.shape == u.shape,
            params,
            no_momentum_params,
        )
    )
|
||||||
|
|
||||||
def test_compiled_optimizer(self):
|
def test_compiled_optimizer(self):
|
||||||
model = nn.Linear(10, 10)
|
model = nn.Linear(10, 10)
|
||||||
x = mx.random.uniform(shape=(2, 10))
|
x = mx.random.uniform(shape=(2, 10))
|
||||||
|
|||||||
Reference in New Issue
Block a user