mirror of
https://github.com/ml-explore/mlx.git
synced 2025-12-16 01:49:05 +08:00
Add CosineAnnealingWarmRestarts scheduler
This commit is contained in:
@@ -156,3 +156,60 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
|
|||||||
return step * ((end - init) / steps) + init
|
return step * ((end - init) / steps) + init
|
||||||
|
|
||||||
return schedule
|
return schedule
|
||||||
|
|
||||||
|
|
||||||
|
def cosine_annealing_warm_restarts(
    init: float, T_0: int, T_mult: int = 1, eta_min: float = 0.0
) -> Callable:
    r"""Make a cosine annealing scheduler with warm restarts.

    The learning rate anneals from ``init`` towards ``eta_min`` along a
    half cosine within each cycle and resets to ``init`` when the next cycle
    begins. The first cycle lasts ``T_0`` steps and every following cycle is
    ``T_mult`` times as long as the one before it.

    Args:
        init (float): Initial learning rate.
        T_0 (int): Number of steps for the first restart.
        T_mult (int, optional): Factor to increase the period after each
          restart. Default: ``1``.
        eta_min (float, optional): Minimum learning rate. Default: ``0.0``.

    Raises:
        ValueError: If ``T_0 < 1``, ``T_mult < 1``, ``eta_min < 0`` or
          ``init < eta_min``.

    Example:
        >>> lr_schedule = optim.cosine_annealing_warm_restarts(1e-1, T_0=10)
        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
        >>>
        >>> for _ in range(11): optimizer.update({}, {})
        ...
        >>> optimizer.learning_rate
        array(0.1, dtype=float32)
    """
    if T_0 < 1:
        raise ValueError(f"T_0 must be at least 1, got {T_0}")
    if T_mult < 1:
        raise ValueError(f"T_mult must be at least 1, got {T_mult}")
    if eta_min < 0:
        raise ValueError(f"eta_min must be non-negative, got {eta_min}")
    if init < eta_min:
        raise ValueError(f"init must be >= eta_min, got init={init}, eta_min={eta_min}")

    def schedule(step):
        # A single-element array is unwrapped to a Python number so the cycle
        # lookup below can use ordinary comparisons; anything else is used
        # as given.
        if isinstance(step, mx.array) and step.size == 1:
            step_val = step.item()
        else:
            step_val = step

        if T_mult == 1:
            # All cycles share the same length.
            T_cur = step_val % T_0
            T_i = T_0
        else:
            # Locate the current cycle by accumulating cycle lengths instead
            # of inverting the geometric series with ``math.log(x, T_mult)``.
            # The two-argument log is a quotient of floating-point logs and
            # can fall just below an exact integer (``math.log(1000, 10)``
            # evaluates to ``2.9999999999999996``), so truncating it selects
            # the previous cycle exactly at a restart step.
            if step_val == math.inf:
                # The loop below would never terminate for an infinite step.
                raise OverflowError("step must be finite")
            cycle_start = 0
            T_i = T_0
            while step_val >= cycle_start + T_i:
                cycle_start += T_i
                T_i *= T_mult
            T_cur = step_val - cycle_start

        cos_inner = math.pi * T_cur / T_i
        return eta_min + (init - eta_min) * 0.5 * (1.0 + mx.cos(cos_inner))

    return schedule
|
||||||
|
|||||||
@@ -446,6 +446,56 @@ class TestSchedulers(mlx_tests.MLXTestCase):
|
|||||||
lr = lr_schedule(20)
|
lr = lr_schedule(20)
|
||||||
self.assertEqual(lr, expected_end_lr)
|
self.assertEqual(lr, expected_end_lr)
|
||||||
|
|
||||||
|
def test_cosine_annealing_warm_restarts(self):
    """Check restart points, in-cycle values and argument validation."""
    make_schedule = opt.cosine_annealing_warm_restarts

    def cosine_lr(peak, floor, position, period):
        # Closed-form value of the schedule `position` steps into a cycle
        # that is `period` steps long.
        return floor + (peak - floor) * 0.5 * (
            1.0 + math.cos(math.pi * position / period)
        )

    # Equal-length cycles (T_mult=1).
    fixed_period = make_schedule(0.1, T_0=10, T_mult=1, eta_min=0.0)

    # The schedule starts at the initial learning rate.
    self.assertAlmostEqual(fixed_period(0), 0.1, delta=1e-7)

    # Half way through a cycle the cosine term is zero, so the rate sits at
    # the midpoint between the initial and the minimum learning rate.
    self.assertAlmostEqual(fixed_period(5), cosine_lr(0.1, 0.0, 5, 10), delta=1e-7)

    # After T_0 steps the schedule restarts at the initial learning rate.
    self.assertAlmostEqual(fixed_period(10), 0.1, delta=1e-7)

    # Cycles that double in length (T_mult=2) with a non-zero floor.
    doubling_period = make_schedule(0.1, T_0=5, T_mult=2, eta_min=0.01)

    # First restart, at the end of the initial 5-step cycle.
    self.assertAlmostEqual(doubling_period(5), 0.1, delta=1e-7)

    # The second cycle is 10 steps long, so the next restart is at step 15.
    self.assertAlmostEqual(doubling_period(15), 0.1, delta=1e-7)

    # Step 10 is the middle of the second cycle; eta_min lifts the value.
    self.assertAlmostEqual(
        doubling_period(10), cosine_lr(0.1, 0.01, 5, 10), delta=1e-7
    )

    # Invalid configurations are rejected when the schedule is built.
    invalid_configs = (
        (0.1, {"T_0": 0}),
        (0.1, {"T_0": 10, "T_mult": 0}),
        (0.1, {"T_0": 10, "eta_min": -0.1}),
        (0.01, {"T_0": 10, "eta_min": 0.1}),
    )
    for init_lr, kwargs in invalid_configs:
        with self.assertRaises(ValueError):
            make_schedule(init_lr, **kwargs)
|
||||||
|
|
||||||
def test_schedule_joiner(self):
|
def test_schedule_joiner(self):
|
||||||
boundaries = [2, 3, 4]
|
boundaries = [2, 3, 4]
|
||||||
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
||||||
|
|||||||
Reference in New Issue
Block a user