Add minimum for cosine decay function (#859)

* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Chime Ogbuji
2024-03-21 10:33:29 -04:00
committed by GitHub
parent a54f06b16f
commit f5a1582fe8
2 changed files with 8 additions and 2 deletions

View File

@@ -328,6 +328,11 @@ class TestSchedulers(unittest.TestCase):
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
lr_schedule = opt.cosine_decay(0.1, 10, 0.05)
lr = lr_schedule(20)
expected_lr = 0.05
self.assertEqual(lr, expected_lr)
def test_schedule_joiner(self):
boundaries = [2, 3, 4]
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]