diff --git a/docs/src/python/optimizers/schedulers.rst b/docs/src/python/optimizers/schedulers.rst
index 50855e1e7..97f5f9874 100644
--- a/docs/src/python/optimizers/schedulers.rst
+++ b/docs/src/python/optimizers/schedulers.rst
@@ -8,8 +8,9 @@ Schedulers
 .. autosummary::
    :toctree: _autosummary
 
-   cosine_decay
-   exponential_decay
+   cosine_decay
+   cyclic_lr
+   exponential_decay
    join_schedules
    linear_schedule
    step_decay
diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py
index 67e4e29cd..69d71a6a3 100644
--- a/python/mlx/optimizers/schedulers.py
+++ b/python/mlx/optimizers/schedulers.py
@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import Callable, List
+from typing import Callable, List, Optional
 
 import mlx.core as mx
 
@@ -156,3 +156,64 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
         return step * ((end - init) / steps) + init
 
     return schedule
+
+
+def cyclic_lr(
+    base_lr: float,
+    max_lr: float,
+    step_size_up: int = 2000,
+    step_size_down: Optional[int] = None,
+    mode: str = "triangular",
+    gamma: float = 1.0,
+) -> Callable:
+    r"""Make a cyclic learning rate scheduler.
+
+    Args:
+        base_lr (float): Lower learning rate boundary.
+        max_lr (float): Upper learning rate boundary.
+        step_size_up (int): Number of steps in the increasing half of a cycle. Default: ``2000``.
+        step_size_down (int, optional): Number of steps in the decreasing half.
+            If ``None``, equals ``step_size_up``. Default: ``None``.
+        mode (str): One of ``"triangular"``, ``"triangular2"``, ``"exp_range"``. Default: ``"triangular"``.
+        gamma (float): Scaling factor for ``"exp_range"`` mode. Default: ``1.0``.
+
+    Example:
+        >>> lr_schedule = optim.cyclic_lr(0.001, 0.1, step_size_up=10)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.001, dtype=float32)
+        >>>
+        >>> for _ in range(5): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0505, dtype=float32)
+    """
+    step_size_down = step_size_down if step_size_down is not None else step_size_up
+    total_size = step_size_up + step_size_down
+    step_ratio = step_size_up / total_size
+
+    def schedule(step):
+        if isinstance(step, mx.array):
+            step_val = step.item() if step.size == 1 else step
+        else:
+            step_val = step
+
+        cycle = math.floor(1 + step_val / total_size)
+        x = 1.0 + step_val / total_size - cycle
+
+        if x <= step_ratio:
+            scale_factor = x / step_ratio
+        else:
+            scale_factor = (x - 1) / (step_ratio - 1)
+
+        if mode == "triangular":
+            scale_fn_val = 1.0
+        elif mode == "triangular2":
+            scale_fn_val = 1 / (2.0 ** (cycle - 1))
+        else:  # exp_range
+            scale_fn_val = gamma ** (cycle - 1)
+
+        base_height = (max_lr - base_lr) * scale_factor
+        return base_lr + base_height * scale_fn_val
+
+    return schedule
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 6869ac357..71ba923bb 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -446,6 +446,23 @@ class TestSchedulers(mlx_tests.MLXTestCase):
         lr = lr_schedule(20)
         self.assertEqual(lr, expected_end_lr)
 
+    def test_cyclic_lr(self):
+        lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=10)
+
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 0.001, delta=1e-7)
+
+        lr = lr_schedule(10)
+        self.assertAlmostEqual(lr, 0.1, delta=1e-7)
+
+        lr = lr_schedule(20)
+        self.assertAlmostEqual(lr, 0.001, delta=1e-7)
+
+        lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=5, mode="triangular2")
+        lr = lr_schedule(15)
+        expected_lr = 0.001 + (0.1 - 0.001) * 0.5
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-6)
+
     def test_schedule_joiner(self):
         boundaries = [2, 3, 4]
         schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
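Usage note (not part of the patch): below is a minimal sketch of the "exp_range" mode, which the new tests do not exercise. It assumes this patch is applied and that mlx.optimizers re-exports cyclic_lr alongside the existing schedulers, as the docstring example above already implies; the printed values follow directly from the schedule closure defined in the patch.

    import mlx.optimizers as optim

    # Sketch only: exercises the "exp_range" mode of the cyclic_lr scheduler
    # introduced by this patch (assumes the patch is applied).
    lr_schedule = optim.cyclic_lr(0.001, 0.1, step_size_up=5, mode="exp_range", gamma=0.99)

    # In "exp_range" mode the peak of each cycle is scaled by gamma ** (cycle - 1),
    # so successive cycles reach progressively lower maxima.
    for step in (0, 5, 10, 15, 20):
        print(step, lr_schedule(step))
    # 0  -> 0.001    (start of cycle 1, scale_factor == 0)
    # 5  -> 0.1      (peak of cycle 1, gamma ** 0 == 1)
    # 10 -> 0.001    (start of cycle 2)
    # 15 -> ~0.09901 (peak of cycle 2, scaled by gamma ** 1 == 0.99)
    # 20 -> 0.001    (start of cycle 3)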