mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add cyclic_lr scheduler
This commit is contained in:
@@ -446,6 +446,23 @@ class TestSchedulers(mlx_tests.MLXTestCase):
|
||||
lr = lr_schedule(20)
|
||||
self.assertEqual(lr, expected_end_lr)
|
||||
|
||||
def test_cyclic_lr(self):
|
||||
lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=10)
|
||||
|
||||
lr = lr_schedule(0)
|
||||
self.assertAlmostEqual(lr, 0.001, delta=1e-7)
|
||||
|
||||
lr = lr_schedule(10)
|
||||
self.assertAlmostEqual(lr, 0.1, delta=1e-7)
|
||||
|
||||
lr = lr_schedule(20)
|
||||
self.assertAlmostEqual(lr, 0.001, delta=1e-7)
|
||||
|
||||
lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=5, mode="triangular2")
|
||||
lr = lr_schedule(15)
|
||||
expected_lr = 0.001 + (0.1 - 0.001) * 0.5
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-6)
|
||||
|
||||
def test_schedule_joiner(self):
|
||||
boundaries = [2, 3, 4]
|
||||
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
||||
|
||||
Reference in New Issue
Block a user