Add linear warmup and schedule joining for use with existing schedules (#721)

* Add linear warmup to schedules for use with existing schedules

* Changed parameters to simplify the most common case (0 initial value)

* Added ScheduleJoiner and updated documentation

* ScheduleJoiner -> join_schedules (à la optax #)

* black compliance

* Different evaluation of schedules

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji authored on 2024-02-26 10:28:48 -05:00 (committed via GitHub)
parent e6418781ab, commit 3b661b7394
3 changed files with 109 additions and 5 deletions
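For reference, a minimal usage sketch of the new API (assuming mlx.optimizers is imported as opt, as in the tests below): a linear warmup is joined to a cosine decay and passed to an optimizer as its learning rate. The step counts and rates here are illustrative, not values from this commit.

# Usage sketch (not part of the diff): 100-step linear warmup from 0 to 1e-3,
# then cosine decay over 1000 steps, joined at step 100.
import mlx.optimizers as opt

warmup = opt.schedulers.linear_schedule(0.0, 1e-3, 100)
cosine = opt.schedulers.cosine_decay(1e-3, 1000)
lr_schedule = opt.schedulers.join_schedules([warmup, cosine], [100])

# Optimizers accept a callable learning rate, as exercised in the tests below.
optimizer = opt.Adam(learning_rate=lr_schedule)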


@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

    def test_schedule_joiner(self):
        boundaries = [2, 3, 4]
        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
        with self.assertRaises(ValueError):
            opt.schedulers.join_schedules(schedules, boundaries)
        boundaries = [2, 4]
        schedule = opt.schedulers.join_schedules(schedules, boundaries)
        self.assertEqual(schedule(0).item(), 3)
        self.assertEqual(schedule(1).item(), 3)
        self.assertEqual(schedule(2).item(), 4)
        self.assertEqual(schedule(3).item(), 4)
        self.assertEqual(schedule(5).item(), 5)
        self.assertEqual(schedule(7).item(), 5)

    def test_linear_warmup_with_cosine_decay(self):
        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
        cos_with_warmup = opt.schedulers.join_schedules(
            [warmup_schedule, cosine_schedule], [101]
        )
        self.assertEqual(cos_with_warmup(0), 0.0)
        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
        optimizer = opt.Adam(learning_rate=cos_with_warmup)
        for _ in range(100):
            optimizer.update({}, {})
        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
        for _ in range(100):
            optimizer.update({}, {})
        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)

    def test_compile_with_schedule(self):
        lr_schedule = opt.exponential_decay(1e-1, 0.9)
        optimizer = opt.SGD(learning_rate=lr_schedule)
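Outside the diff, here is a rough sketch of how a join_schedules-style combinator can be evaluated: start with the first schedule and switch at each boundary, re-basing the step so each schedule sees its own local step count. This is an illustrative reconstruction consistent with what the tests above assert, not the code added by this commit.

import mlx.core as mx

def join_schedules_sketch(schedules, boundaries):
    # Require exactly one more schedule than boundaries, as the ValueError
    # test above expects.
    if len(schedules) != len(boundaries) + 1:
        raise ValueError(
            "Expected the number of schedules to exceed the number of "
            "boundaries by exactly one."
        )

    def schedule(step):
        # Evaluate the first schedule, then switch at each boundary and
        # offset the step so the next schedule starts from zero.
        output = schedules[0](step)
        for boundary, sched in zip(boundaries, schedules[1:]):
            output = mx.where(step < boundary, output, sched(step - boundary))
        return output

    return schedule

For example, with schedules returning 3, 4, and 5 and boundaries [2, 4], steps 0-1 yield 3, steps 2-3 yield 4, and steps 4 and beyond yield 5, matching test_schedule_joiner above.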