Add linear warmup and schedule joining for use with existing schedules (#721)

* Add a linear warmup schedule for use with existing schedules

* Simplified parameters for the most common case (an initial value of 0)

* Added ScheduleJoiner and updated documentation

* Renamed ScheduleJoiner -> join_schedules (à la optax #)

* black compliance

* Different evaluation of schedules (values selected with mx.where)

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2024-02-26 10:28:48 -05:00 committed by GitHub
parent e6418781ab
commit 3b661b7394
3 changed files with 109 additions and 5 deletions
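
The change in brief: the new linear_schedule provides warmup, and join_schedules stitches it onto any existing schedule. A minimal usage sketch, condensed from the docstring examples in the diff below (it assumes a build of mlx that includes this commit):

    import mlx.optimizers as optim

    # Warm up from 0 to 1e-1 over the first 10 steps, then cosine-decay
    # from 1e-1 over the next 200 steps.
    warmup = optim.linear_schedule(0, 1e-1, steps=10)
    cosine = optim.cosine_decay(1e-1, 200)
    lr_schedule = optim.join_schedules([warmup, cosine], [10])

    # Any optimizer that accepts a schedule as its learning rate works here.
    optimizer = optim.Adam(learning_rate=lr_schedule)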


@@ -8,6 +8,8 @@ Schedulers
 .. autosummary::
    :toctree: _autosummary

-   step_decay
-   exponential_decay
    cosine_decay
+   exponential_decay
+   join_schedules
+   linear_schedule
+   step_decay


@@ -1,11 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.

 import math
+from typing import Callable, List

 import mlx.core as mx


-def exponential_decay(init: float, decay_rate: float):
+def exponential_decay(init: float, decay_rate: float) -> Callable:
     r"""Make an exponential decay scheduler.

     Args:
@@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
     return schedule


-def step_decay(init: float, decay_rate: float, step_size: int):
+def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
     r"""Make a step decay scheduler.

     Args:
@@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
     return schedule


-def cosine_decay(init: float, decay_steps: int):
+def cosine_decay(init: float, decay_steps: int) -> Callable:
     r"""Make a cosine decay scheduler.

     Args:
@@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
         return init * decay

     return scheduler
+
+
+def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
+    r"""Join multiple schedules to create a new schedule.
+
+    Args:
+        schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
+            receives a step count indicating the number of steps since
+            the :math:`i`-th boundary.
+        boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
+            that indicates when to transition between schedules.
+
+    Example:
+        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> cosine = optim.cosine_decay(1e-1, 200)
+        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.0, dtype=float32)
+        >>> for _ in range(12): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0999938, dtype=float32)
+    """
+    if len(schedules) == 0:
+        raise ValueError("Must provide at least 1 schedule to join.")
+
+    if len(schedules) != len(boundaries) + 1:
+        raise ValueError(
+            f"Received {len(boundaries)} boundaries but "
+            f"expected {len(schedules) - 1}."
+        )
+
+    def schedule(step):
+        output = schedules[0](step)
+        for boundary, schedule in zip(boundaries, schedules[1:]):
+            output = mx.where(step < boundary, output, schedule(step - boundary))
+        return output
+
+    return schedule
+
+
+def linear_schedule(init: float, end: float, steps: int) -> Callable:
+    r"""Make a linear scheduler.
+
+    Args:
+        init (float): Initial value.
+        end (float): Final value.
+        steps (int): Number of steps to apply the schedule over. The value is
+            ``end`` for any steps beyond ``steps``.
+
+    Example:
+        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> optimizer.learning_rate
+        array(0.0, dtype=float32)
+        >>> for _ in range(101): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+    """
+    if steps < 1:
+        raise ValueError(f"steps must be greater than 0, but got {steps}.")
+
+    def step_fn(step):
+        step = mx.minimum(step, steps)
+        return step * ((end - init) / steps) + init
+
+    return step_fn
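
Note how the joined schedule is evaluated: every sub-schedule is called on every step, and mx.where selects the active value, with the step offset by the preceding boundary, presumably so the schedule stays traceable under mx.compile (compare test_compile_with_schedule below). A small sketch of these semantics using constant schedules for clarity; the helper names here are illustrative only:

    import mlx.core as mx
    import mlx.optimizers as optim

    low = lambda step: mx.array(3.0)   # active for steps 0 and 1
    high = lambda step: mx.array(5.0)  # active from step 2 onward
    joined = optim.join_schedules([low, high], [2])

    print(joined(0).item())  # 3.0 -- step < boundary, first schedule wins
    print(joined(2).item())  # 5.0 -- second schedule sees step - 2 == 0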


@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
         expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
         self.assertAlmostEqual(lr, expected_lr, delta=1e-7)

+    def test_schedule_joiner(self):
+        boundaries = [2, 3, 4]
+        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
+        with self.assertRaises(ValueError):
+            opt.schedulers.join_schedules(schedules, boundaries)
+        boundaries = [2, 4]
+        schedule = opt.schedulers.join_schedules(schedules, boundaries)
+        self.assertEqual(schedule(0).item(), 3)
+        self.assertEqual(schedule(1).item(), 3)
+        self.assertEqual(schedule(2).item(), 4)
+        self.assertEqual(schedule(3).item(), 4)
+        self.assertEqual(schedule(5).item(), 5)
+        self.assertEqual(schedule(7).item(), 5)
+
+    def test_linear_warmup_with_cosine_decay(self):
+        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
+        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
+        cos_with_warmup = opt.schedulers.join_schedules(
+            [warmup_schedule, cosine_schedule], [101]
+        )
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
     def test_compile_with_schedule(self):
         lr_schedule = opt.exponential_decay(1e-1, 0.9)
         optimizer = opt.SGD(learning_rate=lr_schedule)
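
For reference, linear_schedule is the straight line init + step * (end - init) / steps, clamped so that any step at or beyond steps returns end. A hedged sanity check of that closed form, in the style of the tests above (assumes the same mlx build):

    import mlx.optimizers as opt

    sched = opt.schedulers.linear_schedule(0.0, 0.1, steps=100)
    assert abs(sched(0).item() - 0.0) < 1e-6    # starts at init
    assert abs(sched(50).item() - 0.05) < 1e-6  # midpoint of the line
    assert abs(sched(100).item() - 0.1) < 1e-6  # reaches end exactly at steps
    assert abs(sched(500).item() - 0.1) < 1e-6  # clamped past steps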