mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-22 01:21:14 +08:00
Adding support for the Muon Optimizer (#1914)
* initial commit with workong optmimizer * update ACKNOWLEDGMENTS.md * nits and adding it to test * nits * G.astype(mx.bfloat16) to G.astype(G.dtype) * G.ndim >= 2 to assert G.ndim == 2 * remove coments * replace with mx.addmm * remove comments * format * nits * match muon * fix addmm --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
45adec102c
commit
deee214a95
@ -19,6 +19,7 @@ MLX was developed with contributions from the following individuals:
|
|||||||
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
|
||||||
- Paul Paczuski: Improved stability of BCE loss calculation
|
- Paul Paczuski: Improved stability of BCE loss calculation
|
||||||
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
|
||||||
|
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer.
|
||||||
|
|
||||||
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
<a href="https://github.com/ml-explore/mlx/graphs/contributors">
|
||||||
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
<img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
|
||||||
|
@ -19,3 +19,4 @@ Common Optimizers
|
|||||||
Adamax
|
Adamax
|
||||||
Lion
|
Lion
|
||||||
MultiOptimizer
|
MultiOptimizer
|
||||||
|
Muon
|
||||||
|
@ -119,7 +119,6 @@ class MatMul {
|
|||||||
uint64_t b_rows,
|
uint64_t b_rows,
|
||||||
uint64_t b_cols,
|
uint64_t b_cols,
|
||||||
int64_t ldb,
|
int64_t ldb,
|
||||||
bool c_transposed,
|
|
||||||
int64_t ldc,
|
int64_t ldc,
|
||||||
int32_t batch_count,
|
int32_t batch_count,
|
||||||
int64_t a_batch_stride,
|
int64_t a_batch_stride,
|
||||||
@ -141,7 +140,7 @@ class MatMul {
|
|||||||
b_batch_stride) {
|
b_batch_stride) {
|
||||||
auto type = dtype_to_cuda_type(dtype);
|
auto type = dtype_to_cuda_type(dtype);
|
||||||
c_desc_ = create_matrix_layout(
|
c_desc_ = create_matrix_layout(
|
||||||
type, a_rows, b_cols, c_transposed, ldc, batch_count, c_batch_stride);
|
type, a_rows, b_cols, false, ldc, batch_count, c_batch_stride);
|
||||||
}
|
}
|
||||||
|
|
||||||
~MatMul() {
|
~MatMul() {
|
||||||
@ -403,9 +402,7 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
assert(inputs.size() == 3);
|
assert(inputs.size() == 3);
|
||||||
auto& a_pre = inputs[0];
|
auto& a_pre = inputs[0];
|
||||||
auto& b_pre = inputs[1];
|
auto& b_pre = inputs[1];
|
||||||
auto& c_pre = inputs[2];
|
auto c = inputs[2];
|
||||||
|
|
||||||
out.set_data(allocator::malloc(out.nbytes()));
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Init checks and prep
|
// Init checks and prep
|
||||||
@ -418,7 +415,24 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
// the arrays
|
// the arrays
|
||||||
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
auto [a_transposed, lda, a] = check_transpose(encoder, s, a_pre);
|
||||||
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
auto [b_transposed, ldb, b] = check_transpose(encoder, s, b_pre);
|
||||||
auto [c_transposed, ldc, c] = check_transpose(encoder, s, c_pre);
|
|
||||||
|
int64_t ldc;
|
||||||
|
{
|
||||||
|
auto stx = c.strides()[c.ndim() - 2];
|
||||||
|
auto sty = c.strides()[c.ndim() - 1];
|
||||||
|
if (sty == 1 && stx == c.shape(-1)) {
|
||||||
|
ldc = stx;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
} else if (sty == 1 && stx == 0) {
|
||||||
|
ldc = 0;
|
||||||
|
out.set_data(allocator::malloc(out.nbytes()));
|
||||||
|
} else {
|
||||||
|
// Copy C into out and set C to out
|
||||||
|
ldc = c.shape(-1);
|
||||||
|
copy_gpu(c, out, CopyType::General, s);
|
||||||
|
c = out;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////
|
||||||
// Check and collapse batch dimensions
|
// Check and collapse batch dimensions
|
||||||
@ -456,7 +470,6 @@ void AddMM::eval_gpu(const std::vector<array>& inputs, array& out) {
|
|||||||
K,
|
K,
|
||||||
N,
|
N,
|
||||||
ldb,
|
ldb,
|
||||||
c_transposed,
|
|
||||||
ldc,
|
ldc,
|
||||||
batch_shape.back(),
|
batch_shape.back(),
|
||||||
a_batch_strides.back(),
|
a_batch_strides.back(),
|
||||||
|
@ -848,6 +848,106 @@ class Adafactor(Optimizer):
|
|||||||
return parameter - update
|
return parameter - update
|
||||||
|
|
||||||
|
|
||||||
|
class Muon(Optimizer):
|
||||||
|
r"""The Muon optimizer.
|
||||||
|
|
||||||
|
Our Muon (MomentUm Orthogonalized by Newton-schulz) optimizer follows the
|
||||||
|
original implementation: `Muon: An optimizer for hidden layers in neural
|
||||||
|
networks <https://kellerjordan.github.io/posts/muon/>`_
|
||||||
|
|
||||||
|
Note:
|
||||||
|
- Muon may be sub-optimal for the embedding layer, the final fully
|
||||||
|
connected layer, or any 0D/1D parameters. Those should be optimized
|
||||||
|
by a different method (e.g., :class:`AdamW`).
|
||||||
|
- For 4D convolutional filters, it works by flattening their last
|
||||||
|
dimensions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
learning_rate (float or callable): The learning rate.
|
||||||
|
momentum (float, optional): The momentum strength. Default: ``0.95``
|
||||||
|
weight_decay (float, optional): The weight decay (L2 penalty).
|
||||||
|
Default: ``0.01``
|
||||||
|
nesterov (bool, optional): Enables Nesterov momentum. Recommended for
|
||||||
|
better performance. Default: ``True``
|
||||||
|
ns_steps (int, optional): Number of Newton-Schulz iteration steps for
|
||||||
|
orthogonalization. Default: ``5``
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
learning_rate: Union[float, Callable[[mx.array], mx.array]],
|
||||||
|
momentum: float = 0.95,
|
||||||
|
weight_decay: float = 0.01,
|
||||||
|
nesterov: bool = True,
|
||||||
|
ns_steps: int = 5,
|
||||||
|
):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
self._maybe_schedule("learning_rate", learning_rate)
|
||||||
|
self.momentum = momentum
|
||||||
|
self.weight_decay = weight_decay
|
||||||
|
self.nesterov = nesterov
|
||||||
|
self.ns_steps = ns_steps
|
||||||
|
|
||||||
|
def init_single(self, parameter: mx.array, state: dict):
|
||||||
|
"""Initialize optimizer state"""
|
||||||
|
state["v"] = mx.zeros_like(parameter)
|
||||||
|
|
||||||
|
def _zeropower_via_newtonschulz5(self, X, steps: int):
|
||||||
|
assert (
|
||||||
|
X.ndim == 2
|
||||||
|
), f"Expected a 2D array for Newton-Schulz iteration, got shape {X.shape} instead."
|
||||||
|
a, b, c = (3.4445, -4.7750, 2.0315)
|
||||||
|
transpose_needed = X.shape[-2] > X.shape[-1]
|
||||||
|
|
||||||
|
if transpose_needed:
|
||||||
|
X = X.T
|
||||||
|
|
||||||
|
X = X / (mx.linalg.norm(X, keepdims=True) + 1e-7)
|
||||||
|
|
||||||
|
for _ in range(steps):
|
||||||
|
A = X @ X.T
|
||||||
|
B = mx.addmm(b * A, A, A, beta=1.0, alpha=c)
|
||||||
|
X = mx.addmm(a * X, B, X, beta=1.0, alpha=1.0)
|
||||||
|
|
||||||
|
if transpose_needed:
|
||||||
|
X = X.T
|
||||||
|
return X
|
||||||
|
|
||||||
|
def apply_single(self, gradient: mx.array, parameter: mx.array, state: dict):
|
||||||
|
"""Performs the Muon parameter update"""
|
||||||
|
|
||||||
|
if self.weight_decay != 0:
|
||||||
|
gradient = gradient + self.weight_decay * parameter
|
||||||
|
|
||||||
|
v = self.momentum * state["v"]
|
||||||
|
v = v + (1 - self.momentum) * gradient
|
||||||
|
state["v"] = v
|
||||||
|
|
||||||
|
if self.nesterov:
|
||||||
|
update = gradient * (1 - self.momentum) + v * self.momentum
|
||||||
|
else:
|
||||||
|
update = v
|
||||||
|
|
||||||
|
lr = self.learning_rate.astype(gradient.dtype)
|
||||||
|
|
||||||
|
if update.ndim >= 2:
|
||||||
|
original_shape = update.shape
|
||||||
|
reshape_needed = update.ndim > 2
|
||||||
|
|
||||||
|
if reshape_needed:
|
||||||
|
update = mx.reshape(update, (update.shape[0], -1))
|
||||||
|
|
||||||
|
update = self._zeropower_via_newtonschulz5(update, steps=self.ns_steps)
|
||||||
|
|
||||||
|
if reshape_needed:
|
||||||
|
update = mx.reshape(update, original_shape)
|
||||||
|
|
||||||
|
lr *= max(1, update.shape[-2] / update.shape[-1]) ** 0.5
|
||||||
|
|
||||||
|
return parameter - lr * update
|
||||||
|
|
||||||
|
|
||||||
def clip_grad_norm(grads, max_norm):
|
def clip_grad_norm(grads, max_norm):
|
||||||
"""Clips the global norm of the gradients.
|
"""Clips the global norm of the gradients.
|
||||||
|
|
||||||
|
@ -691,6 +691,21 @@ class TestBlas(mlx_tests.MLXTestCase):
|
|||||||
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
self.assertListEqual(list(d_npy.shape), list(d_mlx.shape))
|
||||||
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
self.assertTrue(np.allclose(d_mlx, d_npy, atol=1e-5))
|
||||||
|
|
||||||
|
# Transposed c
|
||||||
|
a = mx.ones((10, 5)).T
|
||||||
|
b = mx.ones((5, 5))
|
||||||
|
out = mx.addmm(a, b, a, beta=1.5, alpha=0.5)
|
||||||
|
expected = 1.5 * a + 0.5 * (b @ a)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
|
# Broadcast c
|
||||||
|
a = mx.ones((5, 5))
|
||||||
|
b = mx.ones((5, 5))
|
||||||
|
c = mx.ones((1, 5))
|
||||||
|
out = mx.addmm(c, a, b, beta=1.5, alpha=0.5)
|
||||||
|
expected = 1.5 * c + 0.5 * (a @ b)
|
||||||
|
self.assertTrue(mx.allclose(expected, out))
|
||||||
|
|
||||||
def test_addmm_grad(self):
|
def test_addmm_grad(self):
|
||||||
def make_ref_addmm(alpha, beta):
|
def make_ref_addmm(alpha, beta):
|
||||||
return lambda c, a, b: alpha * (a @ b) + beta * c
|
return lambda c, a, b: alpha * (a @ b) + beta * c
|
||||||
|
@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
|||||||
self.assertEqual(xp["x"].shape, x.shape)
|
self.assertEqual(xp["x"].shape, x.shape)
|
||||||
self.assertEqual(optimizer.state["step"], 2)
|
self.assertEqual(optimizer.state["step"], 2)
|
||||||
|
|
||||||
|
def test_muon(self):
|
||||||
|
params = {
|
||||||
|
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
|
||||||
|
"second": mx.zeros((3, 3)),
|
||||||
|
"conv": mx.zeros((16, 8, 3, 3)),
|
||||||
|
}
|
||||||
|
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||||
|
|
||||||
|
# Explicit init
|
||||||
|
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
|
||||||
|
optim.init(params)
|
||||||
|
self.assertTrue(
|
||||||
|
tree_equal(
|
||||||
|
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||||
|
params,
|
||||||
|
optim.state,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test update
|
||||||
|
updated_params = optim.apply_gradients(grads, params)
|
||||||
|
|
||||||
|
# Check that shapes are preserved
|
||||||
|
self.assertTrue(
|
||||||
|
tree_equal(
|
||||||
|
lambda p, u: p.shape == u.shape,
|
||||||
|
params,
|
||||||
|
updated_params,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check that parameters actually changed
|
||||||
|
self.assertFalse(
|
||||||
|
tree_equal(
|
||||||
|
lambda p, u: mx.array_equal(p, u),
|
||||||
|
params,
|
||||||
|
updated_params,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test with different configurations
|
||||||
|
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||||
|
optim_no_nesterov.apply_gradients(grads, params)
|
||||||
|
|
||||||
|
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||||
|
optim_no_momentum.apply_gradients(grads, params)
|
||||||
|
|
||||||
def test_compiled_optimizer(self):
|
def test_compiled_optimizer(self):
|
||||||
model = nn.Linear(10, 10)
|
model = nn.Linear(10, 10)
|
||||||
x = mx.random.uniform(shape=(2, 10))
|
x = mx.random.uniform(shape=(2, 10))
|
||||||
|
Loading…
Reference in New Issue
Block a user