mirror of
				https://github.com/ml-explore/mlx.git
				synced 2025-11-04 10:38:10 +08:00 
			
		
		
		
	Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
		@@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
 | 
			
		||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
 | 
			
		||||
- Paul Paczuski: Improved stability of BCE loss calculation
 | 
			
		||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
 | 
			
		||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
 | 
			
		||||
 | 
			
		||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
 | 
			
		||||
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
 | 
			
		||||
 
 | 
			
		||||
@@ -19,3 +19,4 @@ Common Optimizers
 | 
			
		||||
   Adamax
 | 
			
		||||
   Lion
 | 
			
		||||
   MultiOptimizer
 | 
			
		||||
   Muon
 | 
			
		||||
 
 | 
			
		||||
@@ -119,7 +119,6 @@ class MatMul {
 | 
			
		||||
      uint64_t b_rows,
 | 
			
		||||
      uint64_t b_cols,
 | 
			
		||||
      int64_t ldb,
 | 
			
		||||
      bool c_transposed,
 | 
			
		||||
      int64_t ldc,
 | 
			
		||||
      int32_t batch_count,
 | 
			
		||||
      int64_t a_batch_stride,
 | 
			
		||||
@@ -141,7 +140,7 @@ class MatMul {
 | 
			
		||||
            b_batch_stride) {
 | 
			
		||||
    auto type = dtype_to_cuda_type(dtype);
 | 
			
		||||
    c_desc_ = create_matrix_layout(
 | 
			
		||||
        type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
 | 
			
		||||
        type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  ~MatMul() {
 | 
			
		||||
@@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  assert(inputs.size() == 3);
 | 
			
		||||
  auto& a_pre = inputs[0];
 | 
			
		||||
  auto& b_pre = inputs[1];
 | 
			
		||||
  auto& c_pre = inputs[2];
 | 
			
		||||
 | 
			
		||||
  out.set_data(allocator::malloc(out.nbytes()));
 | 
			
		||||
  auto c = inputs[2];
 | 
			
		||||
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Init checks and prep
 | 
			
		||||
@@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
  // the arrays
 | 
			
		||||
  auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
 | 
			
		||||
  auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
 | 
			
		||||
  auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
 | 
			
		||||
 | 
			
		||||
  int64_t ldc;
 | 
			
		||||
  {
 | 
			
		||||
    auto stx = c.strides()[c.ndim() - 2];
 | 
			
		||||
    auto sty = c.strides()[c.ndim() - 1];
 | 
			
		||||
    if (sty == 1 && stx == c.shape(-1)) {
 | 
			
		||||
      ldc = stx;
 | 
			
		||||
      out.set_data(allocator::malloc(out.nbytes()));
 | 
			
		||||
    } else if (sty == 1 && stx == 0) {
 | 
			
		||||
      ldc = 0;
 | 
			
		||||
      out.set_data(allocator::malloc(out.nbytes()));
 | 
			
		||||
    } else {
 | 
			
		||||
      // Copy C into out and set C to out
 | 
			
		||||
      ldc = c.shape(-1);
 | 
			
		||||
      copy_gpu(c, out, CopyType::General, s);
 | 
			
		||||
      c = out;
 | 
			
		||||
    }
 | 
			
		||||
  }
 | 
			
		||||
 | 
			
		||||
  /////////////////////////////////////////////////////////////////////////////
 | 
			
		||||
  // Check and collapse batch dimensions
 | 
			
		||||
@@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
 | 
			
		||||
      K,
 | 
			
		||||
      N,
 | 
			
		||||
      ldb,
 | 
			
		||||
      c_transposed,
 | 
			
		||||
      ldc,
 | 
			
		||||
      batch_shape.back(),
 | 
			
		||||
      a_batch_strides.back(),
 | 
			
		||||
 
 | 
			
		||||
@@ -848,6 +848,106 @@ class Adafactor(Optimizer):
 | 
			
		||||
        return parameter - update
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Muon(Optimizer):
 | 
			
		||||
    r"""The Muon optimizer.
 | 
			
		||||
 | 
			
		||||
    Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
 | 
			
		||||
    original implementation: `Muon: An optimizer for hidden layers in neural
 | 
			
		||||
    networks <https://kellerjordan.github.io/posts/muon/>`_
 | 
			
		||||
 | 
			
		||||
    Note:
 | 
			
		||||
        - Muon may be sub-optimal for the embedding layer, the final fully
 | 
			
		||||
          connected layer, or any 0D/1D parameters. Those should be optimized
 | 
			
		||||
          by a different method (e.g., :class:`AdamW`).
 | 
			
		||||
        - For 4D convolutional filters, it works by flattening their last
 | 
			
		||||
          dimensions.
 | 
			
		||||
 | 
			
		||||
    Args:
 | 
			
		||||
        learning_rate (float or callable): The learning rate.
 | 
			
		||||
        momentum (float, optional): The momentum strength. Default: ``0.95``
 | 
			
		||||
        weight_decay (float, optional): The weight decay (L2 penalty).
 | 
			
		||||
            Default: ``0.01``
 | 
			
		||||
        nesterov (bool, optional): Enables Nesterov momentum. Recommended for
 | 
			
		||||
            better performance.  Default: ``True``
 | 
			
		||||
        ns_steps (int, optional): Number of Newton-Schulz iteration steps for
 | 
			
		||||
            orthogonalization.  Default: ``5``
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def __init__(
 | 
			
		||||
        self,
 | 
			
		||||
        learning_rate: Union[float, Callable[[mx.array], mx.array]],
 | 
			
		||||
        momentum: float = 0.95,
 | 
			
		||||
        weight_decay: float = 0.01,
 | 
			
		||||
        nesterov: bool = True,
 | 
			
		||||
        ns_steps: int = 5,
 | 
			
		||||
    ):
 | 
			
		||||
        super().__init__()
 | 
			
		||||
 | 
			
		||||
        self._maybe_schedule("learning_rate", learning_rate)
 | 
			
		||||
        self.momentum = momentum
 | 
			
		||||
        self.weight_decay = weight_decay
 | 
			
		||||
        self.nesterov = nesterov
 | 
			
		||||
        self.ns_steps = ns_steps
 | 
			
		||||
 | 
			
		||||
    def init_single(self, parameter: mx.array, state: dict):
 | 
			
		||||
        """Initialize optimizer state"""
 | 
			
		||||
        state["v"] = mx.zeros_like(parameter)
 | 
			
		||||
 | 
			
		||||
    def _zeropower_via_newtonschulz5(self, X, steps: int):
 | 
			
		||||
        assert (
 | 
			
		||||
            X.ndim == 2
 | 
			
		||||
        ), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
 | 
			
		||||
        a, b, c = (3.4445, -4.7750, 2.0315)
 | 
			
		||||
        transpose_needed = X.shape[-2] > X.shape[-1]
 | 
			
		||||
 | 
			
		||||
        if transpose_needed:
 | 
			
		||||
            X = X.T
 | 
			
		||||
 | 
			
		||||
        X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
 | 
			
		||||
 | 
			
		||||
        for _ in range(steps):
 | 
			
		||||
            A = X @ X.T
 | 
			
		||||
            B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
 | 
			
		||||
            X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
 | 
			
		||||
 | 
			
		||||
        if transpose_needed:
 | 
			
		||||
            X = X.T
 | 
			
		||||
        return X
 | 
			
		||||
 | 
			
		||||
    def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
 | 
			
		||||
        """Performs the Muon parameter update"""
 | 
			
		||||
 | 
			
		||||
        if self.weight_decay != 0:
 | 
			
		||||
            gradient = gradient + self.weight_decay * parameter
 | 
			
		||||
 | 
			
		||||
        v = self.momentum * state["v"]
 | 
			
		||||
        v = v + (1 - self.momentum) * gradient
 | 
			
		||||
        state["v"] = v
 | 
			
		||||
 | 
			
		||||
        if self.nesterov:
 | 
			
		||||
            update = gradient * (1 - self.momentum) + v * self.momentum
 | 
			
		||||
        else:
 | 
			
		||||
            update = v
 | 
			
		||||
 | 
			
		||||
        lr = self.learning_rate.astype(gradient.dtype)
 | 
			
		||||
 | 
			
		||||
        if update.ndim >= 2:
 | 
			
		||||
            original_shape = update.shape
 | 
			
		||||
            reshape_needed = update.ndim > 2
 | 
			
		||||
 | 
			
		||||
            if reshape_needed:
 | 
			
		||||
                update = mx.reshape(update, (update.shape[0], -1))
 | 
			
		||||
 | 
			
		||||
            update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
 | 
			
		||||
 | 
			
		||||
            if reshape_needed:
 | 
			
		||||
                update = mx.reshape(update, original_shape)
 | 
			
		||||
 | 
			
		||||
            lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
 | 
			
		||||
 | 
			
		||||
        return parameter - lr * update
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def clip_grad_norm(grads, max_norm):
 | 
			
		||||
    """Clips the global norm of the gradients.
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
 | 
			
		||||
            self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
 | 
			
		||||
            self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
 | 
			
		||||
 | 
			
		||||
        # Transposed c
 | 
			
		||||
        a = mx.ones((10, 5)).T
 | 
			
		||||
        b = mx.ones((5, 5))
 | 
			
		||||
        out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
 | 
			
		||||
        expected = 1.5 * a + 0.5 * (b @ a)
 | 
			
		||||
        self.assertTrue(mx.allclose(expected, out))
 | 
			
		||||
 | 
			
		||||
        # Broadcast c
 | 
			
		||||
        a = mx.ones((5, 5))
 | 
			
		||||
        b = mx.ones((5, 5))
 | 
			
		||||
        c = mx.ones((1, 5))
 | 
			
		||||
        out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
 | 
			
		||||
        expected = 1.5 * c + 0.5 * (a @ b)
 | 
			
		||||
        self.assertTrue(mx.allclose(expected, out))
 | 
			
		||||
 | 
			
		||||
    def test_addmm_grad(self):
 | 
			
		||||
        def make_ref_addmm(alpha, beta):
 | 
			
		||||
            return lambda c, a, b: alpha * (a @ b) + beta * c
 | 
			
		||||
 
 | 
			
		||||
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
 | 
			
		||||
            self.assertEqual(xp["x"].shape, x.shape)
 | 
			
		||||
        self.assertEqual(optimizer.state["step"], 2)
 | 
			
		||||
 | 
			
		||||
    def test_muon(self):
 | 
			
		||||
        params = {
 | 
			
		||||
            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
 | 
			
		||||
            "second": mx.zeros((3, 3)),
 | 
			
		||||
            "conv": mx.zeros((16, 8, 3, 3)),
 | 
			
		||||
        }
 | 
			
		||||
        grads = tree_map(lambda x: mx.ones_like(x), params)
 | 
			
		||||
 | 
			
		||||
        # Explicit init
 | 
			
		||||
        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
 | 
			
		||||
        optim.init(params)
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            tree_equal(
 | 
			
		||||
                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
 | 
			
		||||
                params,
 | 
			
		||||
                optim.state,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Test update
 | 
			
		||||
        updated_params = optim.apply_gradients(grads, params)
 | 
			
		||||
 | 
			
		||||
        # Check that shapes are preserved
 | 
			
		||||
        self.assertTrue(
 | 
			
		||||
            tree_equal(
 | 
			
		||||
                lambda p, u: p.shape == u.shape,
 | 
			
		||||
                params,
 | 
			
		||||
                updated_params,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Check that parameters actually changed
 | 
			
		||||
        self.assertFalse(
 | 
			
		||||
            tree_equal(
 | 
			
		||||
                lambda p, u: mx.array_equal(p, u),
 | 
			
		||||
                params,
 | 
			
		||||
                updated_params,
 | 
			
		||||
            )
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
        # Test with different configurations
 | 
			
		||||
        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
 | 
			
		||||
        optim_no_nesterov.apply_gradients(grads, params)
 | 
			
		||||
 | 
			
		||||
        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
 | 
			
		||||
        optim_no_momentum.apply_gradients(grads, params)
 | 
			
		||||
 | 
			
		||||
    def test_compiled_optimizer(self):
 | 
			
		||||
        model = nn.Linear(10, 10)
 | 
			
		||||
        x = mx.random.uniform(shape=(2, 10))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user