mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
nits and adding it to test
This commit is contained in:
@@ -286,6 +286,53 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
self.assertEqual(xp["x"].shape, x.shape)
|
||||
self.assertEqual(optimizer.state["step"], 2)
|
||||
|
||||
def test_muon(self):
|
||||
params = {
|
||||
"first": [mx.zeros((10, 5)), mx.zeros((1,))],
|
||||
"second": mx.zeros((3, 3)),
|
||||
"conv": mx.zeros((16, 8, 3, 3)),
|
||||
}
|
||||
grads = tree_map(lambda x: mx.ones_like(x), params)
|
||||
|
||||
# Explicit init
|
||||
optim = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=True)
|
||||
optim.init(params)
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, s: mx.array_equal(s["v"], mx.zeros_like(p)),
|
||||
params,
|
||||
optim.state,
|
||||
)
|
||||
)
|
||||
|
||||
# Test update
|
||||
updated_params = optim.apply_gradients(grads, params)
|
||||
|
||||
# Check that shapes are preserved
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
lambda p, u: p.shape == u.shape,
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Check that parameters actually changed
|
||||
self.assertFalse(
|
||||
tree_equal(
|
||||
lambda p, u: mx.array_equal(p, u),
|
||||
params,
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
# Test with different configurations
|
||||
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||
optim_no_nesterov.apply_gradients(grads, params)
|
||||
|
||||
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||
optim_no_momentum.apply_gradients(grads, params)
|
||||
|
||||
def test_compiled_optimizer(self):
|
||||
model = nn.Linear(10, 10)
|
||||
x = mx.random.uniform(shape=(2, 10))
|
||||
|
||||
Reference in New Issue
Block a user