From 7e5674d8be8b28a6c2f2074ec4093efdd8fda658 Mon Sep 17 00:00:00 2001 From: jlwitthuhn <5405091+jlwitthuhn@users.noreply.github.com> Date: Mon, 20 May 2024 08:00:48 -0700 Subject: [PATCH] Treat 'minimum' differently in cosine decay (#1138) --- python/mlx/optimizers/schedulers.py | 6 +++--- python/tests/test_optimizers.py | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py index 1481a7499..a8c0354f3 100644 --- a/python/mlx/optimizers/schedulers.py +++ b/python/mlx/optimizers/schedulers.py @@ -58,14 +58,14 @@ def step_decay(init: float, decay_rate: float, step_size: int) -> Callable: return schedule -def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callable: +def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable: r"""Make a cosine decay scheduler. Args: init (float): Initial value. decay_steps (int): Number of steps to decay over. The decayed value is constant for steps beyond ``decay_steps``. - minimum (float, optional): Minimal value to decay to. Default: ``0``. + end (float, optional): Final value to decay to. Default: ``0``. 
Example: @@ -83,7 +83,7 @@ def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callabl def scheduler(step): s = mx.minimum(step, decay_steps) decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s)) - return mx.maximum(init * decay, minimum) + return end + decay * (init - end) return scheduler diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py index 950850b1f..1a6e5e431 100644 --- a/python/tests/test_optimizers.py +++ b/python/tests/test_optimizers.py @@ -329,9 +329,11 @@ class TestSchedulers(unittest.TestCase): self.assertAlmostEqual(lr, expected_lr, delta=1e-7) lr_schedule = opt.cosine_decay(0.1, 10, 0.05) + lr = lr_schedule(9) + expected_end_lr = 0.05 + self.assertGreater(lr, expected_end_lr) lr = lr_schedule(20) - expected_lr = 0.05 - self.assertEqual(lr, expected_lr) + self.assertEqual(lr, expected_end_lr) def test_schedule_joiner(self): boundaries = [2, 3, 4]