mirror of
https://github.com/ml-explore/mlx.git
synced 2025-10-19 08:29:35 +08:00
Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function * Update python/mlx/optimizers/schedulers.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
@@ -328,6 +328,11 @@ class TestSchedulers(unittest.TestCase):
|
||||
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
lr_schedule = opt.cosine_decay(0.1, 10, 0.05)
|
||||
lr = lr_schedule(20)
|
||||
expected_lr = 0.05
|
||||
self.assertEqual(lr, expected_lr)
|
||||
|
||||
def test_schedule_joiner(self):
|
||||
boundaries = [2, 3, 4]
|
||||
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
||||
|
Reference in New Issue
Block a user