Mirror of https://github.com/ml-explore/mlx.git, synced 2025-10-19 00:04:41 +08:00
Add linear warmup and schedule joining for use with existing schedules (#721)
* Add linear warmup to schedules for use with existing schedules
* Changed parameters for simplicity of most common case (0 initial value)
* Added ScheduleJoiner and updated documentation
* ScheduleJoiner -> join_schedules (ala optax #)
* black compliance
* Different evaluation of schedules
* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
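The tests in the diff below exercise the new API end to end. As a quick orientation, here is a minimal usage sketch based only on the calls those tests make (the alias `opt` for `mlx.optimizers` follows the test file's own convention):

    import mlx.optimizers as opt

    # 100 linear warmup steps from 0.0 up to 1e-5, then cosine decay
    # with 1e-5 as the initial value over 100 decay steps.
    warmup = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
    cosine = opt.schedulers.cosine_decay(1e-5, 100)

    # join_schedules takes one boundary fewer than schedules; each
    # boundary is the step at which the next schedule takes over.
    schedule = opt.schedulers.join_schedules([warmup, cosine], [101])

    # A schedule can be passed directly as an optimizer's learning rate.
    optimizer = opt.Adam(learning_rate=schedule)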
@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
         expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
         self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_schedule_joiner(self):
+        boundaries = [2, 3, 4]
+        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
+        with self.assertRaises(ValueError):
+            opt.schedulers.join_schedules(schedules, boundaries)
+        boundaries = [2, 4]
+        schedule = opt.schedulers.join_schedules(schedules, boundaries)
+        self.assertEqual(schedule(0).item(), 3)
+        self.assertEqual(schedule(1).item(), 3)
+        self.assertEqual(schedule(2).item(), 4)
+        self.assertEqual(schedule(3).item(), 4)
+        self.assertEqual(schedule(5).item(), 5)
+        self.assertEqual(schedule(7).item(), 5)
+
+    def test_linear_warmup_with_cosine_decay(self):
+        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
+        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
+        cos_with_warmup = opt.schedulers.join_schedules(
+            [warmup_schedule, cosine_schedule], [101]
+        )
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
 
     def test_compile_with_schedule(self):
         lr_schedule = opt.exponential_decay(1e-1, 0.9)
         optimizer = opt.SGD(learning_rate=lr_schedule)
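For readers skimming `test_schedule_joiner`: the joined schedule evaluates the i-th sub-schedule once i boundaries have been passed, and the warmup test's boundary of 101 suggests each later schedule restarts its own step count at its boundary (cosine decay begins at its initial value right after warmup ends). A hypothetical pure-Python sketch of that behavior, not the MLX implementation (which, per the `.item()` calls above, operates on arrays):

    def join(schedules, boundaries):
        # Mirrors the ValueError the test expects for mismatched lengths.
        if len(boundaries) != len(schedules) - 1:
            raise ValueError("join requires len(boundaries) == len(schedules) - 1")

        def schedule(step):
            # Pick the sub-schedule for this step: one boundary
            # crossed per schedule hand-off.
            i = sum(step >= b for b in boundaries)
            # Assumption: a later schedule sees steps counted from its
            # own boundary, so e.g. cosine decay restarts after warmup.
            offset = boundaries[i - 1] if i > 0 else 0
            return schedules[i](step - offset)

        return schedule

    # Reproduces the constant-schedule expectations from the test above.
    s = join([lambda _: 3, lambda _: 4, lambda _: 5], [2, 4])
    assert [s(t) for t in (0, 1, 2, 3, 5, 7)] == [3, 3, 4, 4, 5, 5]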