mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
nits
This commit is contained in:
@@ -307,7 +307,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
|
||||
# Test update
|
||||
updated_params = optim.apply_gradients(grads, params)
|
||||
|
||||
|
||||
# Check that shapes are preserved
|
||||
self.assertTrue(
|
||||
tree_equal(
|
||||
@@ -316,7 +316,7 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Check that parameters actually changed
|
||||
self.assertFalse(
|
||||
tree_equal(
|
||||
@@ -325,11 +325,11 @@ class TestOptimizers(mlx_tests.MLXTestCase):
|
||||
updated_params,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
# Test with different configurations
|
||||
optim_no_nesterov = opt.Muon(learning_rate=1e-2, momentum=0.95, nesterov=False)
|
||||
optim_no_nesterov.apply_gradients(grads, params)
|
||||
|
||||
|
||||
optim_no_momentum = opt.Muon(learning_rate=1e-2, momentum=0.0)
|
||||
optim_no_momentum.apply_gradients(grads, params)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user