mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-21 16:51:15 +08:00
Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
45adec102c
commit
deee214a95
@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||
|
||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||
|
@ -19,3 +19,4 @@ Common Optimizers
|
||||
Adamax
|
||||
Lion
|
||||
MultiOptimizer
|
||||
Muon
|
||||
|
@ -119,7 +119,6 @@ class MatMul {
|
||||
uint64_t b_rows,
|
||||
uint64_t b_cols,
|
||||
int64_t ldb,
|
||||
bool c_transposed,
|
||||
int64_t ldc,
|
||||
int32_t batch_count,
|
||||
int64_t a_batch_stride,
|
||||
@ -141,7 +140,7 @@ class MatMul {
|
||||
b_batch_stride) {
|
||||
auto type = dtype_to_cuda_type(dtype);
|
||||
c_desc_ = create_matrix_layout(
|
||||
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
||||
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||
}
|
||||
|
||||
~MatMul() {
|
||||
@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
assert(inputs.size() == 3);
|
||||
auto& a_pre = inputs[0];
|
||||
auto& b_pre = inputs[1];
|
||||
auto& c_pre = inputs[2];
|
||||
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
auto c = inputs[2];
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Init checks and prep
|
||||
@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
// the arrays
|
||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
|
||||
|
||||
int64_t ldc;
|
||||
{
|
||||
auto stx = c.strides()[c.ndim() - 2];
|
||||
auto sty = c.strides()[c.ndim() - 1];
|
||||
if (sty == 1 && stx == c.shape(-1)) {
|
||||
ldc = stx;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else if (sty == 1 && stx == 0) {
|
||||
ldc = 0;
|
||||
out.set_data(allocator::malloc(out.nbytes()));
|
||||
} else {
|
||||
// Copy C into out and set C to out
|
||||
ldc = c.shape(-1);
|
||||
copy_gpu(c, out, CopyType::General, s);
|
||||
c = out;
|
||||
}
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////////////
|
||||
// Check and collapse batch dimensions
|
||||
@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
||||
K,
|
||||
N,
|
||||
ldb,
|
||||
c_transposed,
|
||||
ldc,
|
||||
batch_shape.back(),
|
||||
a_batch_strides.back(),
|
||||
|
@ -848,6 +848,106 @@ class Adafactor(Optimizer):
|
||||
return parameter - update
|
||||
|
||||
|
||||
class Muon(Optimizer):
|
||||
r"""The Muon optimizer.
|
||||
|
||||
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
|
||||
original implementation: `Muon: An optimizer for hidden layers in neural
|
||||
networks <https://kellerjordan.github.io/posts/muon/>`_
|
||||
|
||||
Note:
|
||||
- Muon may be sub-optimal for the embedding layer, the final fully
|
||||
connected layer, or any 0D/1D parameters. Those should be optimized
|
||||
by a different method (e.g., :class:`AdamW`).
|
||||
- For 4D convolutional filters, it works by flattening their last
|
||||
dimensions.
|
||||
|
||||
Args:
|
||||
learning_rate (float or callable): The learning rate.
|
||||
momentum (float, optional): The momentum strength. Default: ``0.95``
|
||||
weight_decay (float, optional): The weight decay (L2 penalty).
|
||||
Default: ``0.01``
|
||||
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
|
||||
better performance. Default: ``True``
|
||||
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
|
||||
orthogonalization. Default: ``5``
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||
momentum: float = 0.95,
|
||||
weight_decay: float = 0.01,
|
||||
nesterov: bool = True,
|
||||
ns_steps: int = 5,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self._maybe_schedule("learning_rate", learning_rate)
|
||||
self.momentum = momentum
|
||||
self.weight_decay = weight_decay
|
||||
self.nesterov = nesterov
|
||||
self.ns_steps = ns_steps
|
||||
|
||||
def init_single(self, parameter: mx.array, state: dict):
|
||||
"""Initialize optimizer state"""
|
||||
state["v"] = mx.zeros_like(parameter)
|
||||
|
||||
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||
assert (
|
||||
X.ndim == 2
|
||||
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
|
||||
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
|
||||
|
||||
for _ in range(steps):
|
||||
A = X @ X.T
|
||||
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||
|
||||
if transpose_needed:
|
||||
X = X.T
|
||||
return X
|
||||
|
||||
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||
"""Performs the Muon parameter update"""
|
||||
|
||||
if self.weight_decay != 0:
|
||||
gradient = gradient + self.weight_decay * parameter
|
||||
|
||||
v = self.momentum * state["v"]
|
||||
v = v + (1 - self.momentum) * gradient
|
||||
state["v"] = v
|
||||
|
||||
if self.nesterov:
|
||||
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||
else:
|
||||
update = v
|
||||
|
||||
lr = self.learning_rate.astype(gradient.dtype)
|
||||
|
||||
if update.ndim >= 2:
|
||||
original_shape = update.shape
|
||||
reshape_needed = update.ndim > 2
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, (update.shape[0], -1))
|
||||
|
||||
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||
|
||||
if reshape_needed:
|
||||
update = mx.reshape(update, original_shape)
|
||||
|
||||
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||
|
||||
return parameter - lr * update
|
||||
|
||||
|
||||
def clip_grad_norm(grads, max_norm):
|
||||
"""Clips the global norm of the gradients.
|
||||
|
||||
|
@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
|
||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||
|
||||
# Transposed c
|
||||
a = mx.ones((10, 5)).T
|
||||
b = mx.ones((5, 5))
|
||||
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * a + 0.5 * (b @ a)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
# Broadcast c
|
||||
a = mx.ones((5, 5))
|
||||
b = mx.ones((5, 5))
|
||||
c = mx.ones((1, 5))
|
||||
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
||||
expected = 1.5 * c + 0.5 * (a @ b)
|
||||
self.assertTrue(mx.allclose(expected, out))
|
||||
|
||||
def test_addmm_grad(self):
|
||||
def make_ref_addmm(alpha, beta):
|
||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||
|
@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(xp["x"].shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
|
||||
def test_muon(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
|
||||
"second": mx.zeros((3, 3)),
|
||||
"conv": mx.zeros((16, 8, 3, 3)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
# Test update
|
||||
updated_params = optim.apply_gradients(grads, params)
|
||||
|
||||
# Check that shapes are preserved
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, u: p.shape == u.shape,
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Check that parameters actually changed
|
||||
self.assertFalse(
|
||||
tree_equal(
|
||||
lambda p, u: mx.array_equal(p, u),
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Test with different configurations
|
||||
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||
optim_no_nesterov.apply_gradients(grads, params)
|
||||
|
||||
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||
optim_no_momentum.apply_gradients(grads, params)
|
||||
|
||||
def test_compiled_optimizer(self):
|
||||
model = nn.Linear(10, 10)
|
||||
x = mx.random.uniform(shape=(2, 10))
|
||||
|
Loading…
Reference in New Issue
Block a user