From deee214a95ed0be4576f765172f153592509a2a3 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?G=C3=B6kdeniz=20G=C3=BClmez?= <60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Fri, 18 Jul 2025 21:25:28 +0200
Subject: [PATCH] Adding support for the Muon Optimizer (#1914)

* initial commit with working optimizer

* update ACKNOWLEDGMENTS.md

* nits and adding it to test

* nits

* G.astype(mx.bfloat16) to G.astype(G.dtype)

* G.ndim >= 2 to assert G.ndim == 2

* remove comments

* replace with mx.addmm

* remove comments

* format

* nits

* match muon

* fix addmm

---------

Co-authored-by: Awni Hannun
---
 ACKNOWLEDGMENTS.md                               |   1 +
 .../python/optimizers/common_optimizers.rst      |   1 +
 mlx/backend/cuda/matmul.cpp                      |  27 +++--
 python/mlx/optimizers/optimizers.py              | 100 ++++++++++++++++++
 python/tests/test_blas.py                        |  15 +++
 python/tests/test_optimizers.py                  |  47 ++++++++
 6 files changed, 184 insertions(+), 7 deletions(-)

diff --git a/ACKNOWLEDGMENTS.md b/ACKNOWLEDGMENTS.md
index 4b0cea123..786c9042c 100644
--- a/ACKNOWLEDGMENTS.md
+++ b/ACKNOWLEDGMENTS.md
@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 - Paul Paczuski: Improved stability of BCE loss calculation
 - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
+- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.

diff --git a/docs/src/python/optimizers/common_optimizers.rst b/docs/src/python/optimizers/common_optimizers.rst
index 86f800135..4975df541 100644
--- a/docs/src/python/optimizers/common_optimizers.rst
+++ b/docs/src/python/optimizers/common_optimizers.rst
@@ -19,3 +19,4 @@ Common Optimizers
    Adamax
    Lion
    MultiOptimizer
+   Muon

diff --git a/mlx/backend/cuda/matmul.cpp b/mlx/backend/cuda/matmul.cpp
index 4110e7eff..c50fe7fee 100644
--- a/mlx/backend/cuda/matmul.cpp
+++ b/mlx/backend/cuda/matmul.cpp
@@ -119,7 +119,6 @@ class MatMul {
       uint64_t b_rows,
       uint64_t b_cols,
       int64_t ldb,
-      bool c_transposed,
       int64_t ldc,
       int32_t batch_count,
       int64_t a_batch_stride,
@@ -141,7 +140,7 @@ class MatMul {
         b_batch_stride) {
     auto type = dtype_to_cuda_type(dtype);
     c_desc_ = create_matrix_layout(
-        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
+        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
   }

   ~MatMul() {
@@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   assert(inputs.size() == 3);
   auto& a_pre = inputs[0];
   auto& b_pre = inputs[1];
-  auto& c_pre = inputs[2];
-
-  out.set_data(allocator::malloc(out.nbytes()));
+  auto c = inputs[2];

   /////////////////////////////////////////////////////////////////////////////
   // Init checks and prep
@@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
   // the arrays
   auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
   auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
-  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
+
+  int64_t ldc;
+  {
+    auto stx = c.strides()[c.ndim() - 2];
+    auto sty = c.strides()[c.ndim() - 1];
+    if (sty == 1 && stx == c.shape(-1)) {
+      ldc = stx;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else if (sty == 1 && stx == 0) {
+      ldc = 0;
+      out.set_data(allocator::malloc(out.nbytes()));
+    } else {
+      // Copy C into out and set C to out
+      ldc = c.shape(-1);
+      copy_gpu(c, out, CopyType::General, s);
+      c = out;
+    }
+  }
 /////////////////////////////////////////////////////////////////////////////
 // Check and collapse batch dimensions
@@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
       K,
       N,
       ldb,
-      c_transposed,
       ldc,
       batch_shape.back(),
       a_batch_strides.back(),
diff --git a/python/mlx/optimizers/optimizers.py b/python/mlx/optimizers/optimizers.py
index 26b732ebd..07b68cc5b 100644
--- a/python/mlx/optimizers/optimizers.py
+++ b/python/mlx/optimizers/optimizers.py
@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
         return parameter - update


+class Muon(Optimizer):
+    r"""The Muon optimizer.
+
+    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
+    original implementation: `Muon: An optimizer for hidden layers in neural
+    networks `_
+
+    Note:
+      - Muon may be sub-optimal for the embedding layer, the final fully
+        connected layer, or any 0D/1D parameters. Those should be optimized
+        by a different method (e.g., :class:`AdamW`).
+      - For 4D convolutional filters, it works by flattening their last
+        dimensions.
+
+    Args:
+        learning_rate (float or callable): The learning rate.
+        momentum (float, optional): The momentum strength. Default: ``0.95``
+        weight_decay (float, optional): The weight decay (L2 penalty).
+          Default: ``0.01``
+        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
+          better performance. Default: ``True``
+        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
+          orthogonalization. Default: ``5``
+    """
+
+    def __init__(
+        self,
+        learning_rate: Union[float, Callable[[mx.array], mx.array]],
+        momentum: float = 0.95,
+        weight_decay: float = 0.01,
+        nesterov: bool = True,
+        ns_steps: int = 5,
+    ):
+        super().__init__()
+
+        self._maybe_schedule("learning_rate", learning_rate)
+        self.momentum = momentum
+        self.weight_decay = weight_decay
+        self.nesterov = nesterov
+        self.ns_steps = ns_steps
+
+    def init_single(self, parameter: mx.array, state: dict):
+        """Initialize optimizer state"""
+        state["v"] = mx.zeros_like(parameter)
+
+    def _zeropower_via_newtonschulz5(self, X, steps: int):
+        assert (
+            X.ndim == 2
+        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
+        a, b, c = (3.4445, -4.7750, 2.0315)
+        transpose_needed = X.shape[-2] > X.shape[-1]
+
+        if transpose_needed:
+            X = X.T
+
+        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
+
+        for _ in range(steps):
+            A = X @ X.T
+            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
+            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
+
+        if transpose_needed:
+            X = X.T
+        return X
+
+    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
+        """Performs the Muon parameter update"""
+
+        if self.weight_decay != 0:
+            gradient = gradient + self.weight_decay * parameter
+
+        v = self.momentum * state["v"]
+        v = v + (1 - self.momentum) * gradient
+        state["v"] = v
+
+        if self.nesterov:
+            update = gradient * (1 - self.momentum) + v * self.momentum
+        else:
+            update = v
+
+        lr = self.learning_rate.astype(gradient.dtype)
+
+        if update.ndim >= 2:
+            original_shape = update.shape
+            reshape_needed = update.ndim > 2
+
+            if reshape_needed:
+                update = mx.reshape(update, (update.shape[0], -1))
+
+            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
+
+            if reshape_needed:
+                update = mx.reshape(update, original_shape)
+
+            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
+
+        return parameter - lr * update
+
+
 def clip_grad_norm(grads, max_norm):
     """Clips the global norm of the gradients.
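For reference, a minimal usage sketch of the new optimizer in a standard MLX training step. This is illustrative only and not part of the patch; the model architecture, loss function, and data below are placeholders chosen to keep the example short.

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

# Placeholder model and data; any mlx.nn module works the same way.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))
optimizer = optim.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

x = mx.random.normal((8, 32))
y = mx.random.randint(0, 10, (8,))

loss_and_grad_fn = nn.value_and_grad(model, loss_fn)
loss, grads = loss_and_grad_fn(model, x, y)

# 2D weight matrices receive the Newton-Schulz orthogonalized update; 0D/1D
# parameters such as biases skip orthogonalization and get a plain momentum
# step (the docstring recommends a separate optimizer, e.g. AdamW, for those).
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)

In practice one would route embeddings, the output layer, and 0D/1D parameters to a different optimizer as the docstring notes; the sketch applies Muon everywhere only for brevity.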
diff --git a/python/tests/test_blas.py b/python/tests/test_blas.py
index 5e096d9c5..7cc39f06a 100644
--- a/python/tests/test_blas.py
+++ b/python/tests/test_blas.py
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
         self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
         self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))

+        # Transposed c
+        a = mx.ones((10, 5)).T
+        b = mx.ones((5, 5))
+        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
+        expected = 1.5 * a + 0.5 * (b @ a)
+        self.assertTrue(mx.allclose(expected, out))
+
+        # Broadcast c
+        a = mx.ones((5, 5))
+        b = mx.ones((5, 5))
+        c = mx.ones((1, 5))
+        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
+        expected = 1.5 * c + 0.5 * (a @ b)
+        self.assertTrue(mx.allclose(expected, out))
+
     def test_addmm_grad(self):
         def make_ref_addmm(alpha, beta):
             return lambda c, a, b: alpha * (a @ b) + beta * c
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 8f9e33679..6869ac357 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)

+    def test_muon(self):
+        params = {
+            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
+            "second": mx.zeros((3, 3)),
+            "conv": mx.zeros((16, 8, 3, 3)),
+        }
+        grads = tree_map(lambda x: mx.ones_like(x), params)
+
+        # Explicit init
+        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
+        optim.init(params)
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+
+        # Test update
+        updated_params = optim.apply_gradients(grads, params)
+
+        # Check that shapes are preserved
+        self.assertTrue(
+            tree_equal(
+                lambda p, u: p.shape == u.shape,
+                params,
+                updated_params,
+            )
+        )
+
+        # Check that parameters actually changed
+        self.assertFalse(
+            tree_equal(
+                lambda p, u: mx.array_equal(p, u),
+                params,
+                updated_params,
+            )
+        )
+
+        # Test with different configurations
+        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
+        optim_no_nesterov.apply_gradients(grads, params)
+
+        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
+        optim_no_momentum.apply_gradients(grads, params)
+
     def test_compiled_optimizer(self):
         model = nn.Linear(10, 10)
         x = mx.random.uniform(shape=(2, 10))
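To make the `mx.addmm`-based Newton-Schulz iteration concrete, here is a standalone sketch, not part of the patch, that applies the same quintic iteration from `_zeropower_via_newtonschulz5` to a random matrix and compares how close the result is to having orthonormal rows. The helper names, the 5x10 test matrix, and the printed comparison are illustrative assumptions only.

import mlx.core as mx

def newton_schulz5(X, steps=5):
    # Same quintic coefficients as Muon._zeropower_via_newtonschulz5 above.
    a, b, c = (3.4445, -4.7750, 2.0315)
    X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
    for _ in range(steps):
        A = X @ X.T
        # mx.addmm(C, A, B, alpha=..., beta=...) computes beta * C + alpha * (A @ B).
        B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
        X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
    return X

def distance_from_orthonormal(M):
    # Frobenius distance of M @ M.T from the identity.
    return mx.linalg.norm(M @ M.T - mx.eye(M.shape[0])).item()

G = mx.random.normal((5, 10))  # wide matrix, like a weight after Muon's transpose handling
O = newton_schulz5(G)

# The iteration pushes singular values toward 1, so O is only approximately
# orthonormal, but it should be far closer than the normalized input.
print(distance_from_orthonormal(G / (mx.linalg.norm(G) + 1e-7)))
print(distance_from_orthonormal(O))

The transposed and broadcast `c` cases added to `test_blas.py` above exercise exactly the C-operand layouts that the reworked `ldc` handling in `mlx/backend/cuda/matmul.cpp` now classifies explicitly.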