From 84ef89f548b9d855d8fc8db0a724ac27537ec88d Mon Sep 17 00:00:00 2001
From: Vincent Amato
Date: Mon, 11 Aug 2025 19:17:47 -0400
Subject: [PATCH] Add CosineAnnealingWarmRestarts scheduler

---
 python/mlx/optimizers/schedulers.py | 57 +++++++++++++++++++++++++++++
 python/tests/test_optimizers.py     | 50 +++++++++++++++++++++++++
 2 files changed, 107 insertions(+)

diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py
index 67e4e29cd..7e6804f2b 100644
--- a/python/mlx/optimizers/schedulers.py
+++ b/python/mlx/optimizers/schedulers.py
@@ -156,3 +156,60 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
         return step * ((end - init) / steps) + init
 
     return schedule
+
+
+def cosine_annealing_warm_restarts(
+    init: float, T_0: int, T_mult: int = 1, eta_min: float = 0.0
+) -> Callable:
+    r"""Make a cosine annealing scheduler with warm restarts.
+
+    The learning rate anneals using a cosine schedule and resets at each warm restart.
+
+    Args:
+        init (float): Initial learning rate.
+        T_0 (int): Number of steps in the first restart period.
+        T_mult (int, optional): Factor by which the period grows after each restart. Default: ``1``.
+        eta_min (float, optional): Minimum learning rate. Default: ``0.0``.
+
+    Example:
+        >>> lr_schedule = optim.cosine_annealing_warm_restarts(1e-1, T_0=10)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+        >>>
+        >>> for _ in range(11): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+    """
+    if T_0 < 1:
+        raise ValueError(f"T_0 must be at least 1, got {T_0}")
+    if T_mult < 1:
+        raise ValueError(f"T_mult must be at least 1, got {T_mult}")
+    if eta_min < 0:
+        raise ValueError(f"eta_min must be non-negative, got {eta_min}")
+    if init < eta_min:
+        raise ValueError(f"init must be >= eta_min, got init={init}, eta_min={eta_min}")
+
+    def schedule(step):
+        if isinstance(step, mx.array):
+            step_val = step.item() if step.size == 1 else step
+        else:
+            step_val = step
+
+        if T_mult == 1:
+            T_cur = step_val % T_0
+            T_i = T_0
+        else:
+            if step_val >= T_0:
+                n = int(math.log((step_val / T_0 * (T_mult - 1) + 1), T_mult))
+                T_cur = step_val - T_0 * (T_mult**n - 1) / (T_mult - 1)
+                T_i = T_0 * T_mult**n
+            else:
+                T_i = T_0
+                T_cur = step_val
+
+        cos_inner = math.pi * T_cur / T_i
+        return eta_min + (init - eta_min) * 0.5 * (1.0 + mx.cos(cos_inner))
+
+    return schedule
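
For a quick sense of the restart behavior, here is a small usage sketch (illustrative only, not part of the patch): with T_0=5 and T_mult=2 the schedule restarts at steps 5 and 15, and each period is twice as long as the previous one.

    import mlx.optimizers as optim

    lr_schedule = optim.cosine_annealing_warm_restarts(
        1e-1, T_0=5, T_mult=2, eta_min=1e-2
    )
    for step in (0, 4, 5, 14, 15):
        # The rate starts at 1e-1, decays toward eta_min within a cycle,
        # and snaps back to 1e-1 at each restart (steps 5 and 15).
        print(step, lr_schedule(step).item())
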
diff --git a/python/tests/test_optimizers.py b/python/tests/test_optimizers.py
index 6869ac357..080377678 100644
--- a/python/tests/test_optimizers.py
+++ b/python/tests/test_optimizers.py
@@ -446,6 +446,56 @@ class TestSchedulers(mlx_tests.MLXTestCase):
         lr = lr_schedule(20)
         self.assertEqual(lr, expected_end_lr)
 
+    def test_cosine_annealing_warm_restarts(self):
+        # Test with T_mult=1 (equal periods)
+        lr_schedule = opt.cosine_annealing_warm_restarts(
+            0.1, T_0=10, T_mult=1, eta_min=0.0
+        )
+
+        # Test initial value
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 0.1, delta=1e-7)
+
+        # Test mid-cycle (halfway between init and eta_min)
+        lr = lr_schedule(5)
+        expected_lr = 0.0 + (0.1 - 0.0) * 0.5 * (1.0 + math.cos(math.pi * 5 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+        # Test restart (after T_0 steps)
+        lr = lr_schedule(10)
+        self.assertAlmostEqual(lr, 0.1, delta=1e-7)
+
+        # Test with T_mult=2 (doubling periods)
+        lr_schedule = opt.cosine_annealing_warm_restarts(
+            0.1, T_0=5, T_mult=2, eta_min=0.01
+        )
+
+        # Test first restart (end of the 5-step first cycle)
+        lr = lr_schedule(5)
+        self.assertAlmostEqual(lr, 0.1, delta=1e-7)
+
+        # Test second restart (second cycle is 10 steps long, so step 15)
+        lr = lr_schedule(15)
+        self.assertAlmostEqual(lr, 0.1, delta=1e-7)
+
+        # Test with eta_min
+        lr = lr_schedule(10)  # Midpoint of the second cycle
+        expected_lr = 0.01 + (0.1 - 0.01) * 0.5 * (1.0 + math.cos(math.pi * 5 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+        # Test error handling
+        with self.assertRaises(ValueError):
+            opt.cosine_annealing_warm_restarts(0.1, T_0=0)
+
+        with self.assertRaises(ValueError):
+            opt.cosine_annealing_warm_restarts(0.1, T_0=10, T_mult=0)
+
+        with self.assertRaises(ValueError):
+            opt.cosine_annealing_warm_restarts(0.1, T_0=10, eta_min=-0.1)
+
+        with self.assertRaises(ValueError):
+            opt.cosine_annealing_warm_restarts(0.01, T_0=10, eta_min=0.1)
+
     def test_schedule_joiner(self):
         boundaries = [2, 3, 4]
         schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
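
For reference, the closed form used in schedule() when T_mult > 1 follows from a geometric series: cycle n starts at step T_0 * (T_mult**n - 1) / (T_mult - 1), so the current cycle index is n = floor(log base T_mult of (step / T_0 * (T_mult - 1) + 1)). A small standalone check of that identity (illustrative only, mirroring the patch's arithmetic):

    import math

    T_0, T_mult = 5, 2
    # Restart boundaries of the first three cycles: steps 0, 5, and 15.
    starts = [T_0 * (T_mult**n - 1) // (T_mult - 1) for n in range(3)]
    for s in starts:
        n = int(math.log(s / T_0 * (T_mult - 1) + 1, T_mult))
        T_cur = s - T_0 * (T_mult**n - 1) / (T_mult - 1)
        # T_cur == 0 at each boundary, so the cosine term is 1 and the
        # schedule returns exactly init.
        assert T_cur == 0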