Adding support for the Muon Optimizer (#1914)

* initial commit with working optimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove comments

* replace with mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Gökdeniz Gülmez 2025-07-18 21:25:28 +02:00 committed by GitHub
parent 45adec102c
commit deee214a95
6 changed files with 184 additions and 7 deletions


@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />


@@ -19,3 +19,4 @@ Common Optimizers
Adamax
Lion
MultiOptimizer
Muon


@@ -119,7 +119,6 @@ class MatMul {
uint64_t b_rows,
uint64_t b_cols,
int64_t ldb,
bool c_transposed,
int64_t ldc,
int32_t batch_count,
int64_t a_batch_stride,
@@ -141,7 +140,7 @@ class MatMul {
b_batch_stride) {
auto type = dtype_to_cuda_type(dtype);
c_desc_ = create_matrix_layout(
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
}
~MatMul() {
@@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
assert(inputs.size() == 3);
auto& a_pre = inputs[0];
auto& b_pre = inputs[1];
auto& c_pre = inputs[2];
auto c = inputs[2];
out.set_data(allocator::malloc(out.nbytes()));
/////////////////////////////////////////////////////////////////////////////
// Init checks and prep
@@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
// the arrays
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
int64_t ldc;
{
auto stx = c.strides()[c.ndim() - 2];
auto sty = c.strides()[c.ndim() - 1];
if (sty == 1 && stx == c.shape(-1)) {
ldc = stx;
out.set_data(allocator::malloc(out.nbytes()));
} else if (sty == 1 && stx == 0) {
ldc = 0;
out.set_data(allocator::malloc(out.nbytes()));
} else {
// Copy C into out and set C to out
ldc = c.shape(-1);
copy_gpu(c, out, CopyType::General, s);
c = out;
}
}
/////////////////////////////////////////////////////////////////////////////
// Check and collapse batch dimensions
@@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
K,
N,
ldb,
c_transposed,
ldc,
batch_shape.back(),
a_batch_strides.back(),
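
For context on the change above: `mx.addmm(c, a, b, alpha, beta)` computes `beta * c + alpha * (a @ b)`. Dropping the `c_transposed` path and computing `ldc` directly lets `c` be a broadcast row (leading stride 0) or a non-contiguous array, which is first copied into the output buffer before the GEMM. A minimal sketch of the resulting Python-level behavior (shapes here are illustrative, not taken from the diff):

```python
import mlx.core as mx

a = mx.random.normal((4, 3))
b = mx.random.normal((3, 5))
c = mx.ones((1, 5))  # a single row, broadcast across the 4 output rows

# addmm fuses the scaled matmul with the addition of c
out = mx.addmm(c, a, b, alpha=0.5, beta=1.5)
expected = 1.5 * c + 0.5 * (a @ b)
assert mx.allclose(out, expected)
```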


@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
return parameter - update
class Muon(Optimizer):
r"""The Muon optimizer.
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
original implementation: `Muon: An optimizer for hidden layers in neural
networks <https://kellerjordan.github.io/posts/muon/>`_
Note:
- Muon may be sub-optimal for the embedding layer, the final fully
connected layer, or any 0D/1D parameters. Those should be optimized
by a different method (e.g., :class:`AdamW`).
- For 4D convolutional filters, it works by flattening their last
dimensions.
Args:
learning_rate (float or callable): The learning rate.
momentum (float, optional): The momentum strength. Default: ``0.95``
weight_decay (float, optional): The weight decay (L2 penalty).
Default: ``0.01``
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
better performance. Default: ``True``
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
orthogonalization. Default: ``5``
"""
def __init__(
self,
learning_rate: Union[float, Callable[[mx.array], mx.array]],
momentum: float = 0.95,
weight_decay: float = 0.01,
nesterov: bool = True,
ns_steps: int = 5,
):
super().__init__()
self._maybe_schedule("learning_rate", learning_rate)
self.momentum = momentum
self.weight_decay = weight_decay
self.nesterov = nesterov
self.ns_steps = ns_steps
def init_single(self, parameter: mx.array, state: dict):
"""Initialize optimizer state"""
state["v"] = mx.zeros_like(parameter)
def _zeropower_via_newtonschulz5(self, X, steps: int):
assert (
X.ndim == 2
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
a, b, c = (3.4445, -4.7750, 2.0315)
transpose_needed = X.shape[-2] > X.shape[-1]
if transpose_needed:
X = X.T
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
for _ in range(steps):
A = X @ X.T
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
if transpose_needed:
X = X.T
return X
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
"""Performs the Muon parameter update"""
if self.weight_decay != 0:
gradient = gradient + self.weight_decay * parameter
v = self.momentum * state["v"]
v = v + (1 - self.momentum) * gradient
state["v"] = v
if self.nesterov:
update = gradient * (1 - self.momentum) + v * self.momentum
else:
update = v
lr = self.learning_rate.astype(gradient.dtype)
if update.ndim >= 2:
original_shape = update.shape
reshape_needed = update.ndim > 2
if reshape_needed:
update = mx.reshape(update, (update.shape[0], -1))
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
if reshape_needed:
update = mx.reshape(update, original_shape)
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
return parameter - lr * update
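
For readers unfamiliar with the method, the `_zeropower_via_newtonschulz5` helper above approximately orthogonalizes the momentum-averaged update: the matrix is transposed if it has more rows than columns, normalized by its Frobenius norm, and then driven toward an orthogonal matrix by the quintic Newton-Schulz iteration from the linked Muon write-up,

    X_{k+1} = a X_k + b (X_k X_k^T) X_k + c (X_k X_k^T)^2 X_k,   (a, b, c) = (3.4445, -4.7750, 2.0315),

where the two `mx.addmm` calls compute `B = b*A + c*(A @ A)` and `X = a*X + B @ X` for `A = X @ X.T`. A minimal usage sketch, assuming the usual `mlx.nn` / `mlx.optimizers` imports (the model and data below are illustrative, not part of this diff):

```python
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(32, 32)  # the 2D weight is a good fit for Muon; the 1D bias
                           # falls back to the plain momentum update (ndim < 2)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.normal((8, 32))
y = mx.random.normal((8, 32))

optimizer = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True, ns_steps=5)
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)
```

As the docstring notes, embeddings, the final projection layer, and 0D/1D parameters are usually better served by a different optimizer such as `AdamW`.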
def clip_grad_norm(grads, max_norm):
"""Clips the global norm of the gradients.


@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
# Transposed c
a = mx.ones((10, 5)).T
b = mx.ones((5, 5))
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
expected = 1.5 * a + 0.5 * (b @ a)
self.assertTrue(mx.allclose(expected, out))
# Broadcast c
a = mx.ones((5, 5))
b = mx.ones((5, 5))
c = mx.ones((1, 5))
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
expected = 1.5 * c + 0.5 * (a @ b)
self.assertTrue(mx.allclose(expected, out))
def test_addmm_grad(self):
def make_ref_addmm(alpha, beta):
return lambda c, a, b: alpha * (a @ b) + beta * c


@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
self.assertEqual(xp["x"].shape, x.shape)
self.assertEqual(optimizer.state["step"], 2)
def test_muon(self):
params = {
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
"second": mx.zeros((3, 3)),
"conv": mx.zeros((16, 8, 3, 3)),
}
grads = tree_map(lambda x: mx.ones_like(x), params)
# Explicit init
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
optim.init(params)
self.assertTrue(
tree_equal(
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
params,
optim.state,
)
)
# Test update
updated_params = optim.apply_gradients(grads, params)
# Check that shapes are preserved
self.assertTrue(
tree_equal(
lambda p, u: p.shape == u.shape,
params,
updated_params,
)
)
# Check that parameters actually changed
self.assertFalse(
tree_equal(
lambda p, u: mx.array_equal(p, u),
params,
updated_params,
)
)
# Test with different configurations
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
optim_no_nesterov.apply_gradients(grads, params)
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
optim_no_momentum.apply_gradients(grads, params)
def test_compiled_optimizer(self):
model = nn.Linear(10, 10)
x = mx.random.uniform(shape=(2, 10))