Add cyclic_lr scheduler

This commit is contained in:
Vincent Amato
2025-08-11 20:40:33 -04:00
parent 7fde1b6a1e
commit 1d9ce9d744
3 changed files with 82 additions and 3 deletions

View File

@@ -446,6 +446,23 @@ class TestSchedulers(mlx_tests.MLXTestCase):
lr = lr_schedule(20)
self.assertEqual(lr, expected_end_lr)
def test_cyclic_lr(self):
lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=10)
lr = lr_schedule(0)
self.assertAlmostEqual(lr, 0.001, delta=1e-7)
lr = lr_schedule(10)
self.assertAlmostEqual(lr, 0.1, delta=1e-7)
lr = lr_schedule(20)
self.assertAlmostEqual(lr, 0.001, delta=1e-7)
lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=5, mode="triangular2")
lr = lr_schedule(15)
expected_lr = 0.001 + (0.1 - 0.001) * 0.5
self.assertAlmostEqual(lr, expected_lr, delta=1e-6)
def test_schedule_joiner(self):
boundaries = [2, 3, 4]
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]