mirror of https://github.com/ml-explore/mlx.git (synced 2025-12-16 01:49:05 +08:00)
nits and adding Muon to the tests
@@ -933,13 +933,13 @@ class Muon(Optimizer):
             gradient = gradient + self.weight_decay * parameter
 
         # Update momentum buffer
-        v = self.momentum * state.get("v")
+        v = self.momentum * state["v"]
         v = v + (1 - self.momentum) * gradient
         state["v"] = v
 
         # Get effective gradient
         if self.nesterov:
-            effective_grad = gradient * self.momentum + v * (1 - self.momentum)
+            effective_grad = gradient * (1 - self.momentum) + v * self.momentum
         else:
             effective_grad = v
 
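Two fixes land in this hunk: the momentum-buffer read now uses state["v"] directly, matching how the buffer is written back a few lines below, and the Nesterov blend now weights the raw gradient by (1 - momentum) and the buffer by momentum, consistent with the EMA-style buffer update above. A minimal standalone sketch, assuming nothing beyond mlx.core and the hyperparameters used in the new test, makes the corrected blend easy to check numerically:

# A sketch for illustration, not the library code: the corrected
# Nesterov blend from the hunk above, evaluated on toy inputs.
import mlx.core as mx

momentum = 0.95
gradient = mx.ones((3, 3))
v_prev = mx.zeros((3, 3))  # freshly initialized momentum buffer

# EMA-style momentum buffer, as in the diff
v = momentum * v_prev + (1 - momentum) * gradient

# Old blend: gradient * momentum + v * (1 - momentum)  -> 0.9525 on this first step
# New blend: gradient * (1 - momentum) + v * momentum  -> 0.0975 on this first step
effective_grad = gradient * (1 - momentum) + v * momentum
print(effective_grad)

With the old ordering the first step is dominated by the raw gradient rather than the still nearly empty buffer; the new ordering matches the gradient-lerped-toward-the-buffer form used by common Muon reference implementations.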
@@ -963,7 +963,8 @@ class Muon(Optimizer):
         orthogonalized_grad = mx.reshape(orthogonalized_grad, original_shape)
 
         # Calculate scaling factor
-        scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        # scale_factor = max(1, parameter.shape[-2] / parameter.shape[-1]) ** 0.5
+        scale_factor = max(1, effective_grad.shape[-2] / effective_grad.shape[-1]) ** 0.5
 
         return parameter - self.learning_rate.astype(gradient.dtype) * orthogonalized_grad * scale_factor
 
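This hunk keeps the original expression commented out and recomputes the aspect-ratio scale from effective_grad rather than from parameter. For 2-D weights the two are identical; the difference matters when the gradient of a higher-rank parameter (for example a convolution weight) is flattened to a matrix before the Newton-Schulz orthogonalization, since the flattened matrix's last two axes are not the parameter's last two axes. A small sketch under that assumption, with a made-up conv shape, shows where the formulas diverge:

# Sketch under the assumption that conv gradients are flattened to 2-D
# before orthogonalization; the (64, 3, 3, 3) shape is a made-up example.
conv_shape = (64, 3, 3, 3)  # hypothetical 4-D conv-style weight
flat_shape = (conv_shape[0], conv_shape[1] * conv_shape[2] * conv_shape[3])  # (64, 27)

old_scale = max(1, conv_shape[-2] / conv_shape[-1]) ** 0.5  # uses (3, 3)   -> 1.0
new_scale = max(1, flat_shape[-2] / flat_shape[-1]) ** 0.5  # uses (64, 27) -> ~1.54
print(old_scale, new_scale)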
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
         self.assertEqual(xp["x"].shape, x.shape)
         self.assertEqual(optimizer.state["step"], 2)
 
+    def test_muon(self):
+        params = {
+            "first": [mx.zeros((10, 5)), mx.zeros((1,))],
+            "second": mx.zeros((3, 3)),
+            "conv": mx.zeros((16, 8, 3, 3)),
+        }
+        grads = tree_map(lambda x: mx.ones_like(x), params)
+
+        # Explicit init
+        optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
+        optim.init(params)
+        self.assertTrue(
+            tree_equal(
+                lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
+                params,
+                optim.state,
+            )
+        )
+
+        # Test update
+        updated_params = optim.apply_gradients(grads, params)
+
+        # Check that shapes are preserved
+        self.assertTrue(
+            tree_equal(
+                lambda p, u: p.shape == u.shape,
+                params,
+                updated_params,
+            )
+        )
+
+        # Check that parameters actually changed
+        self.assertFalse(
+            tree_equal(
+                lambda p, u: mx.array_equal(p, u),
+                params,
+                updated_params,
+            )
+        )
+
+        # Test with different configurations
+        optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
+        optim_no_nesterov.apply_gradients(grads, params)
+
+        optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
+        optim_no_momentum.apply_gradients(grads, params)
+
     def test_compiled_optimizer(self):
         model = nn.Linear(10, 10)
         x = mx.random.uniform(shape=(2, 10))
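The new test_muon drives Muon through the standard optimizer interface (init, apply_gradients) on a parameter tree that mixes 2-D, 1-D, and 4-D arrays, then repeats the update with Nesterov disabled and with zero momentum. A hedged usage sketch in the same spirit, doing a single training step on an nn.Linear; the loss function and data below are illustrative and not taken from the test:

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt

model = nn.Linear(10, 10)
optimizer = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)

def loss_fn(model, x, y):
    return nn.losses.mse_loss(model(x), y)

x = mx.random.uniform(shape=(2, 10))
y = mx.zeros((2, 10))

# One optimization step: value_and_grad builds the gradient tree and
# optimizer.update applies the Muon step to each parameter.
loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state)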