diff --git a/docs/src/python/optimizers/schedulers.rst b/docs/src/python/optimizers/schedulers.rst
index a83883ddb..50855e1e7 100644
--- a/docs/src/python/optimizers/schedulers.rst
+++ b/docs/src/python/optimizers/schedulers.rst
@@ -8,6 +8,8 @@ Schedulers
 .. autosummary::
    :toctree: _autosummary
 
-   step_decay
-   exponential_decay
    cosine_decay
+   exponential_decay
+   join_schedules
+   linear_schedule
+   step_decay
diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py
index da058c03a..d4bf5e126 100644
--- a/python/mlx/optimizers/schedulers.py
+++ b/python/mlx/optimizers/schedulers.py
@@ -1,11 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
+from typing import Callable, List
 
 import mlx.core as mx
 
 
-def exponential_decay(init: float, decay_rate: float):
+def exponential_decay(init: float, decay_rate: float) -> Callable:
     r"""Make an exponential decay scheduler.
 
     Args:
@@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
     return schedule
 
 
-def step_decay(init: float, decay_rate: float, step_size: int):
+def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
     r"""Make a step decay scheduler.
 
     Args:
@@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
     return schedule
 
 
-def cosine_decay(init: float, decay_steps: int):
+def cosine_decay(init: float, decay_steps: int) -> Callable:
     r"""Make a cosine decay scheduler.
 
     Args:
@@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
         return init * decay
 
     return scheduler
+
+
+def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
+    r"""Join multiple schedules to create a new schedule.
+
+    Args:
+        schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
+          receives a step count indicating the number of steps since
+          the :math:`i`-th boundary.
+        boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
+          that indicates when to transition between schedules.
+
+    Example:
+        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> cosine = optim.cosine_decay(1e-1, 200)
+        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.0, dtype=float32)
+        >>> for _ in range(12): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0999938, dtype=float32)
+    """
+    if len(schedules) == 0:
+        raise ValueError("Must provide at least 1 schedule to join.")
+
+    if len(schedules) != len(boundaries) + 1:
+        raise ValueError(
+            f"Received {len(boundaries)} boundaries but "
+            f"expected {len(schedules) - 1}."
+        )
+
+    def schedule(step):
+        output = schedules[0](step)
+        for boundary, schedule in zip(boundaries, schedules[1:]):
+            output = mx.where(step < boundary, output, schedule(step - boundary))
+        return output
+
+    return schedule
+
+
+def linear_schedule(init: float, end: float, steps: int) -> Callable:
+    r"""Make a linear scheduler.
+
+    Args:
+        init (float): Initial value.
+        end (float): Final value.
+        steps (int): Number of steps to apply the schedule over. The value is
+          ``end`` for any steps beyond ``steps``.
+
+    Example:
+
+        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> optimizer.learning_rate
+        array(0.0, dtype=float32)
+        >>> for _ in range(101): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+    """
+    if steps < 1:
+        raise ValueError(f"steps must be greater than 0, but got {steps}.")
+
+    def step_fn(step):
+        step = mx.minimum(step, steps)
+        return step * ((end - init) / steps) + init
+
+    return step_fn
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index f978943de..5c28938dc 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
         expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
         self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
 
+    def test_schedule_joiner(self):
+        boundaries = [2, 3, 4]
+        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
+        with self.assertRaises(ValueError):
+            opt.schedulers.join_schedules(schedules, boundaries)
+        boundaries = [2, 4]
+        schedule = opt.schedulers.join_schedules(schedules, boundaries)
+        self.assertEqual(schedule(0).item(), 3)
+        self.assertEqual(schedule(1).item(), 3)
+        self.assertEqual(schedule(2).item(), 4)
+        self.assertEqual(schedule(3).item(), 4)
+        self.assertEqual(schedule(5).item(), 5)
+        self.assertEqual(schedule(7).item(), 5)
+
+    def test_linear_warmup_with_cosine_decay(self):
+        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
+        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
+        cos_with_warmup = opt.schedulers.join_schedules(
+            [warmup_schedule, cosine_schedule], [101]
+        )
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
     def test_compile_with_schedule(self):
         lr_schedule = opt.exponential_decay(1e-1, 0.9)
         optimizer = opt.SGD(learning_rate=lr_schedule)
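
The new schedulers are meant to be composed. A minimal usage sketch, assuming ``mlx`` is installed and imported as in the docstring examples above (the module alias ``optim`` and the empty ``update({}, {})`` calls mirror those examples and are not part of the diff itself):

```python
# Usage sketch: linear warmup joined with cosine decay (mirrors the docstring examples).
import mlx.optimizers as optim

# Warm up from 0 to 1e-1 over 10 steps, then decay with a cosine schedule over 200 steps.
warmup = optim.linear_schedule(0.0, 1e-1, steps=10)
cosine = optim.cosine_decay(1e-1, 200)

# The single boundary at step 10 switches from the first schedule to the second;
# the second schedule receives the step count relative to that boundary.
lr_schedule = optim.join_schedules([warmup, cosine], [10])

optimizer = optim.Adam(learning_rate=lr_schedule)
for _ in range(12):
    optimizer.update({}, {})  # empty model/gradients, as in the docstring example

print(optimizer.learning_rate)  # close to 1e-1 just past the warmup boundary
```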