mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 | 
				
			|||||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 | 
					- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 | 
				
			||||||
- Paul Paczuski: Improved stability of BCE loss calculation
 | 
					- Paul Paczuski: Improved stability of BCE loss calculation
 | 
				
			||||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
 | 
					- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
 | 
				
			||||||
 | 
					- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
 | 
					<a href="https://github.com/ml-explore/mlx/graphs/contributors">
 | 
				
			||||||
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 | 
					  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -19,3 +19,4 @@ Common Optimizers
 | 
				
			|||||||
   Adamax
 | 
					   Adamax
 | 
				
			||||||
   Lion
 | 
					   Lion
 | 
				
			||||||
   MultiOptimizer
 | 
					   MultiOptimizer
 | 
				
			||||||
 | 
					   Muon
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -119,7 +119,6 @@ class MatMul {
 | 
				
			|||||||
      uint64_t b_rows,
 | 
					      uint64_t b_rows,
 | 
				
			||||||
      uint64_t b_cols,
 | 
					      uint64_t b_cols,
 | 
				
			||||||
      int64_t ldb,
 | 
					      int64_t ldb,
 | 
				
			||||||
      bool c_transposed,
 | 
					 | 
				
			||||||
      int64_t ldc,
 | 
					      int64_t ldc,
 | 
				
			||||||
      int32_t batch_count,
 | 
					      int32_t batch_count,
 | 
				
			||||||
      int64_t a_batch_stride,
 | 
					      int64_t a_batch_stride,
 | 
				
			||||||
@@ -141,7 +140,7 @@ class MatMul {
 | 
				
			|||||||
            b_batch_stride) {
 | 
					            b_batch_stride) {
 | 
				
			||||||
    auto type = dtype_to_cuda_type(dtype);
 | 
					    auto type = dtype_to_cuda_type(dtype);
 | 
				
			||||||
    c_desc_ = create_matrix_layout(
 | 
					    c_desc_ = create_matrix_layout(
 | 
				
			||||||
        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
 | 
					        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
 | 
				
			||||||
  }
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  ~MatMul() {
 | 
					  ~MatMul() {
 | 
				
			||||||
@@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  assert(inputs.size() == 3);
 | 
					  assert(inputs.size() == 3);
 | 
				
			||||||
  auto& a_pre = inputs[0];
 | 
					  auto& a_pre = inputs[0];
 | 
				
			||||||
  auto& b_pre = inputs[1];
 | 
					  auto& b_pre = inputs[1];
 | 
				
			||||||
  auto& c_pre = inputs[2];
 | 
					  auto c = inputs[2];
 | 
				
			||||||
 | 
					 | 
				
			||||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /////////////////////////////////////////////////////////////////////////////
 | 
					  /////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
  // Init checks and prep
 | 
					  // Init checks and prep
 | 
				
			||||||
@@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
  // the arrays
 | 
					  // the arrays
 | 
				
			||||||
  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
 | 
					  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
 | 
				
			||||||
  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
 | 
					  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
 | 
				
			||||||
  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
 | 
					
 | 
				
			||||||
 | 
					  int64_t ldc;
 | 
				
			||||||
 | 
					  {
 | 
				
			||||||
 | 
					    auto stx = c.strides()[c.ndim() - 2];
 | 
				
			||||||
 | 
					    auto sty = c.strides()[c.ndim() - 1];
 | 
				
			||||||
 | 
					    if (sty == 1 && stx == c.shape(-1)) {
 | 
				
			||||||
 | 
					      ldc = stx;
 | 
				
			||||||
 | 
					      out.set_data(allocator::malloc(out.nbytes()));
 | 
				
			||||||
 | 
					    } else if (sty == 1 && stx == 0) {
 | 
				
			||||||
 | 
					      ldc = 0;
 | 
				
			||||||
 | 
					      out.set_data(allocator::malloc(out.nbytes()));
 | 
				
			||||||
 | 
					    } else {
 | 
				
			||||||
 | 
					      // Copy C into out and set C to out
 | 
				
			||||||
 | 
					      ldc = c.shape(-1);
 | 
				
			||||||
 | 
					      copy_gpu(c, out, CopyType::General, s);
 | 
				
			||||||
 | 
					      c = out;
 | 
				
			||||||
 | 
					    }
 | 
				
			||||||
 | 
					  }
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  /////////////////////////////////////////////////////////////////////////////
 | 
					  /////////////////////////////////////////////////////////////////////////////
 | 
				
			||||||
  // Check and collapse batch dimensions
 | 
					  // Check and collapse batch dimensions
 | 
				
			||||||
@@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
				
			|||||||
      K,
 | 
					      K,
 | 
				
			||||||
      N,
 | 
					      N,
 | 
				
			||||||
      ldb,
 | 
					      ldb,
 | 
				
			||||||
      c_transposed,
 | 
					 | 
				
			||||||
      ldc,
 | 
					      ldc,
 | 
				
			||||||
      batch_shape.back(),
 | 
					      batch_shape.back(),
 | 
				
			||||||
      a_batch_strides.back(),
 | 
					      a_batch_strides.back(),
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
 | 
				
			|||||||
        return parameter - update
 | 
					        return parameter - update
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class Muon(Optimizer):
 | 
				
			||||||
 | 
					    r"""The Muon optimizer.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
 | 
				
			||||||
 | 
					    original implementation: `Muon: An optimizer for hidden layers in neural
 | 
				
			||||||
 | 
					    networks <https://kellerjordan.github.io/posts/muon/>`_
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Note:
 | 
				
			||||||
 | 
					        - Muon may be sub-optimal for the embedding layer, the final fully
 | 
				
			||||||
 | 
					          connected layer, or any 0D/1D parameters. Those should be optimized
 | 
				
			||||||
 | 
					          by a different method (e.g., :class:`AdamW`).
 | 
				
			||||||
 | 
					        - For 4D convolutional filters, it works by flattening their last
 | 
				
			||||||
 | 
					          dimensions.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Args:
 | 
				
			||||||
 | 
					        learning_rate (float or callable): The learning rate.
 | 
				
			||||||
 | 
					        momentum (float, optional): The momentum strength. Default: ``0.95``
 | 
				
			||||||
 | 
					        weight_decay (float, optional): The weight decay (L2 penalty).
 | 
				
			||||||
 | 
					            Default: ``0.01``
 | 
				
			||||||
 | 
					        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
 | 
				
			||||||
 | 
					            better performance.  Default: ``True``
 | 
				
			||||||
 | 
					        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
 | 
				
			||||||
 | 
					            orthogonalization.  Default: ``5``
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(
 | 
				
			||||||
 | 
					        self,
 | 
				
			||||||
 | 
					        learning_rate: Union[float, Callable[[mx.array], mx.array]],
 | 
				
			||||||
 | 
					        momentum: float = 0.95,
 | 
				
			||||||
 | 
					        weight_decay: float = 0.01,
 | 
				
			||||||
 | 
					        nesterov: bool = True,
 | 
				
			||||||
 | 
					        ns_steps: int = 5,
 | 
				
			||||||
 | 
					    ):
 | 
				
			||||||
 | 
					        super().__init__()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        self._maybe_schedule("learning_rate", learning_rate)
 | 
				
			||||||
 | 
					        self.momentum = momentum
 | 
				
			||||||
 | 
					        self.weight_decay = weight_decay
 | 
				
			||||||
 | 
					        self.nesterov = nesterov
 | 
				
			||||||
 | 
					        self.ns_steps = ns_steps
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def init_single(self, parameter: mx.array, state: dict):
 | 
				
			||||||
 | 
					        """Initialize optimizer state"""
 | 
				
			||||||
 | 
					        state["v"] = mx.zeros_like(parameter)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def _zeropower_via_newtonschulz5(self, X, steps: int):
 | 
				
			||||||
 | 
					        assert (
 | 
				
			||||||
 | 
					            X.ndim == 2
 | 
				
			||||||
 | 
					        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
 | 
				
			||||||
 | 
					        a, b, c = (3.4445, -4.7750, 2.0315)
 | 
				
			||||||
 | 
					        transpose_needed = X.shape[-2] > X.shape[-1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if transpose_needed:
 | 
				
			||||||
 | 
					            X = X.T
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for _ in range(steps):
 | 
				
			||||||
 | 
					            A = X @ X.T
 | 
				
			||||||
 | 
					            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
 | 
				
			||||||
 | 
					            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if transpose_needed:
 | 
				
			||||||
 | 
					            X = X.T
 | 
				
			||||||
 | 
					        return X
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
 | 
				
			||||||
 | 
					        """Performs the Muon parameter update"""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.weight_decay != 0:
 | 
				
			||||||
 | 
					            gradient = gradient + self.weight_decay * parameter
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        v = self.momentum * state["v"]
 | 
				
			||||||
 | 
					        v = v + (1 - self.momentum) * gradient
 | 
				
			||||||
 | 
					        state["v"] = v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self.nesterov:
 | 
				
			||||||
 | 
					            update = gradient * (1 - self.momentum) + v * self.momentum
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            update = v
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        lr = self.learning_rate.astype(gradient.dtype)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if update.ndim >= 2:
 | 
				
			||||||
 | 
					            original_shape = update.shape
 | 
				
			||||||
 | 
					            reshape_needed = update.ndim > 2
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if reshape_needed:
 | 
				
			||||||
 | 
					                update = mx.reshape(update, (update.shape[0], -1))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            if reshape_needed:
 | 
				
			||||||
 | 
					                update = mx.reshape(update, original_shape)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        return parameter - lr * update
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
def clip_grad_norm(grads, max_norm):
 | 
					def clip_grad_norm(grads, max_norm):
 | 
				
			||||||
    """Clips the global norm of the gradients.
 | 
					    """Clips the global norm of the gradients.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
 | 
				
			|||||||
            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
 | 
					            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
 | 
				
			||||||
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 | 
					            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Transposed c
 | 
				
			||||||
 | 
					        a = mx.ones((10, 5)).T
 | 
				
			||||||
 | 
					        b = mx.ones((5, 5))
 | 
				
			||||||
 | 
					        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
 | 
				
			||||||
 | 
					        expected = 1.5 * a + 0.5 * (b @ a)
 | 
				
			||||||
 | 
					        self.assertTrue(mx.allclose(expected, out))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Broadcast c
 | 
				
			||||||
 | 
					        a = mx.ones((5, 5))
 | 
				
			||||||
 | 
					        b = mx.ones((5, 5))
 | 
				
			||||||
 | 
					        c = mx.ones((1, 5))
 | 
				
			||||||
 | 
					        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
 | 
				
			||||||
 | 
					        expected = 1.5 * c + 0.5 * (a @ b)
 | 
				
			||||||
 | 
					        self.assertTrue(mx.allclose(expected, out))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_addmm_grad(self):
 | 
					    def test_addmm_grad(self):
 | 
				
			||||||
        def make_ref_addmm(alpha, beta):
 | 
					        def make_ref_addmm(alpha, beta):
 | 
				
			||||||
            return lambda c, a, b: alpha * (a @ b) + beta * c
 | 
					            return lambda c, a, b: alpha * (a @ b) + beta * c
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 | 
				
			|||||||
            self.assertEqual(xp["x"].shape, x.shape)
 | 
					            self.assertEqual(xp["x"].shape, x.shape)
 | 
				
			||||||
        self.assertEqual(optimizer.state["step"], 2)
 | 
					        self.assertEqual(optimizer.state["step"], 2)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def test_muon(self):
 | 
				
			||||||
 | 
					        params = {
 | 
				
			||||||
 | 
					            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
 | 
				
			||||||
 | 
					            "second": mx.zeros((3, 3)),
 | 
				
			||||||
 | 
					            "conv": mx.zeros((16, 8, 3, 3)),
 | 
				
			||||||
 | 
					        }
 | 
				
			||||||
 | 
					        grads = tree_map(lambda x: mx.ones_like(x), params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Explicit init
 | 
				
			||||||
 | 
					        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
 | 
				
			||||||
 | 
					        optim.init(params)
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            tree_equal(
 | 
				
			||||||
 | 
					                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
 | 
				
			||||||
 | 
					                params,
 | 
				
			||||||
 | 
					                optim.state,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Test update
 | 
				
			||||||
 | 
					        updated_params = optim.apply_gradients(grads, params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Check that shapes are preserved
 | 
				
			||||||
 | 
					        self.assertTrue(
 | 
				
			||||||
 | 
					            tree_equal(
 | 
				
			||||||
 | 
					                lambda p, u: p.shape == u.shape,
 | 
				
			||||||
 | 
					                params,
 | 
				
			||||||
 | 
					                updated_params,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Check that parameters actually changed
 | 
				
			||||||
 | 
					        self.assertFalse(
 | 
				
			||||||
 | 
					            tree_equal(
 | 
				
			||||||
 | 
					                lambda p, u: mx.array_equal(p, u),
 | 
				
			||||||
 | 
					                params,
 | 
				
			||||||
 | 
					                updated_params,
 | 
				
			||||||
 | 
					            )
 | 
				
			||||||
 | 
					        )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        # Test with different configurations
 | 
				
			||||||
 | 
					        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
 | 
				
			||||||
 | 
					        optim_no_nesterov.apply_gradients(grads, params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
 | 
				
			||||||
 | 
					        optim_no_momentum.apply_gradients(grads, params)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def test_compiled_optimizer(self):
 | 
					    def test_compiled_optimizer(self):
 | 
				
			||||||
        model = nn.Linear(10, 10)
 | 
					        model = nn.Linear(10, 10)
 | 
				
			||||||
        x = mx.random.uniform(shape=(2, 10))
 | 
					        x = mx.random.uniform(shape=(2, 10))
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user