mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add CosineAnnealingWarmRestarts scheduler
This commit is contained in:
@@ -446,6 +446,56 @@ class TestSchedulers(mlx_tests.MLXTestCase):
|
||||
lr = lr_schedule(20)
|
||||
self.assertEqual(lr, expected_end_lr)
|
||||
|
||||
def test_cosine_annealing_warm_restarts(self):
|
||||
# Test with T_mult=1 (equal periods)
|
||||
lr_schedule = opt.cosine_annealing_warm_restarts(
|
||||
0.1, T_0=10, T_mult=1, eta_min=0.0
|
||||
)
|
||||
|
||||
# Test initial value
|
||||
lr = lr_schedule(0)
|
||||
self.assertAlmostEqual(lr, 0.1, delta=1e-7)
|
||||
|
||||
# Test mid-cycle (should be minimum for T_mult=1)
|
||||
lr = lr_schedule(5)
|
||||
expected_lr = 0.0 + (0.1 - 0.0) * 0.5 * (1.0 + math.cos(math.pi * 5 / 10))
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
# Test restart (after T_0 steps)
|
||||
lr = lr_schedule(10)
|
||||
self.assertAlmostEqual(lr, 0.1, delta=1e-7)
|
||||
|
||||
# Test with T_mult=2 (doubling periods)
|
||||
lr_schedule = opt.cosine_annealing_warm_restarts(
|
||||
0.1, T_0=5, T_mult=2, eta_min=0.01
|
||||
)
|
||||
|
||||
# Test first cycle restart
|
||||
lr = lr_schedule(5)
|
||||
self.assertAlmostEqual(lr, 0.1, delta=1e-7)
|
||||
|
||||
# Test second cycle (should be 10 steps long, restart at step 15)
|
||||
lr = lr_schedule(15)
|
||||
self.assertAlmostEqual(lr, 0.1, delta=1e-7)
|
||||
|
||||
# Test with eta_min
|
||||
lr = lr_schedule(10) # Mid of second cycle
|
||||
expected_lr = 0.01 + (0.1 - 0.01) * 0.5 * (1.0 + math.cos(math.pi * 5 / 10))
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
# Test error handling
|
||||
with self.assertRaises(ValueError):
|
||||
opt.cosine_annealing_warm_restarts(0.1, T_0=0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
opt.cosine_annealing_warm_restarts(0.1, T_0=10, T_mult=0)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
opt.cosine_annealing_warm_restarts(0.1, T_0=10, eta_min=-0.1)
|
||||
|
||||
with self.assertRaises(ValueError):
|
||||
opt.cosine_annealing_warm_restarts(0.01, T_0=10, eta_min=0.1)
|
||||
|
||||
def test_schedule_joiner(self):
|
||||
boundaries = [2, 3, 4]
|
||||
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
||||
|
||||
Reference in New Issue
Block a user