This commit is contained in:
Awni Hannun
2025-07-17 06:26:43 -07:00
parent baad6e392b
commit 7f39e9c299
3 changed files with 50 additions and 47 deletions

View File

@@ -307,7 +307,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
# Test update
updated_params = optim.apply_gradients(grads, params)
# Check that shapes are preserved
self.assertTrue(
tree_equal(
@@ -316,7 +316,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
updated_params,
)
)
# Check that parameters actually changed
self.assertFalse(
tree_equal(
@@ -325,11 +325,11 @@ class TestOptimizers(mlx_tests.MLXTestCase):
updated_params,
)
)
# Test with different configurations
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
optim_no_nesterov.apply_gradients(grads, params)
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
optim_no_momentum.apply_gradients(grads, params)