Add cyclic_lr scheduler

@@ -9,6 +9,7 @@ Schedulers
    :toctree: _autosummary
 
    cosine_decay
+   cyclic_lr
    exponential_decay
    join_schedules
    linear_schedule

@@ -1,7 +1,7 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
-from typing import Callable, List
+from typing import Callable, List, Optional
 
 import mlx.core as mx
 

@@ -156,3 +156,64 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
         return step * ((end - init) / steps) + init
 
     return schedule
+
+
+def cyclic_lr(
+    base_lr: float,
+    max_lr: float,
+    step_size_up: int = 2000,
+    step_size_down: Optional[int] = None,
+    mode: str = "triangular",
+    gamma: float = 1.0,
+) -> Callable:
+    r"""Make a cyclic learning rate scheduler.
+
+    Args:
+        base_lr (float): Lower learning rate boundary.
+        max_lr (float): Upper learning rate boundary.
+        step_size_up (int): Number of steps in the increasing half of a cycle. Default: ``2000``.
+        step_size_down (int, optional): Number of steps in the decreasing half.
+            If ``None``, equals ``step_size_up``. Default: ``None``.
+        mode (str): One of ``"triangular"``, ``"triangular2"``, ``"exp_range"``. Default: ``"triangular"``.
+        gamma (float): Scaling factor for ``"exp_range"`` mode. Default: ``1.0``.
+
+    Example:
+        >>> lr_schedule = optim.cyclic_lr(0.001, 0.1, step_size_up=10)
+        >>> optimizer = optim.SGD(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.001, dtype=float32)
+        >>>
+        >>> for _ in range(5): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0505, dtype=float32)
+    """
+    step_size_down = step_size_down if step_size_down is not None else step_size_up
+    total_size = step_size_up + step_size_down
+    step_ratio = step_size_up / total_size
+
+    def schedule(step):
+        if isinstance(step, mx.array):
+            step_val = step.item() if step.size == 1 else step
+        else:
+            step_val = step
+
+        cycle = math.floor(1 + step_val / total_size)
+        x = 1.0 + step_val / total_size - cycle
+
+        if x <= step_ratio:
+            scale_factor = x / step_ratio
+        else:
+            scale_factor = (x - 1) / (step_ratio - 1)
+
+        if mode == "triangular":
+            scale_fn_val = 1.0
+        elif mode == "triangular2":
+            scale_fn_val = 1 / (2.0 ** (cycle - 1))
+        else:  # exp_range
+            scale_fn_val = gamma ** (cycle - 1)
+
+        base_height = (max_lr - base_lr) * scale_factor
+        return base_lr + base_height * scale_fn_val
+
+    return schedule
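
The body of schedule above is a triangular wave: cycle counts which cycle the step falls in, x is the fractional position within that cycle, and scale_factor rises linearly to 1.0 over the first step_ratio of the cycle, then falls back to 0. To trace the waveform without MLX installed, here is a minimal pure-Python sketch of the same arithmetic (cyclic_lr_value is a hypothetical helper name for illustration, not part of the MLX API; the mx.array handling is omitted):

    import math

    def cyclic_lr_value(step, base_lr, max_lr, step_size_up=2000,
                        step_size_down=None, mode="triangular", gamma=1.0):
        # Mirrors the scalar arithmetic of the schedule in the diff above.
        step_size_down = step_size_down if step_size_down is not None else step_size_up
        total_size = step_size_up + step_size_down
        step_ratio = step_size_up / total_size

        cycle = math.floor(1 + step / total_size)  # 1-based cycle index
        x = 1.0 + step / total_size - cycle        # position within the cycle, in [0, 1)

        # Rise linearly over the first step_ratio of the cycle, then fall back.
        if x <= step_ratio:
            scale_factor = x / step_ratio
        else:
            scale_factor = (x - 1) / (step_ratio - 1)

        if mode == "triangular":
            scale_fn_val = 1.0                       # constant peak every cycle
        elif mode == "triangular2":
            scale_fn_val = 1 / (2.0 ** (cycle - 1))  # peak halves each cycle
        else:  # exp_range
            scale_fn_val = gamma ** (cycle - 1)      # peak decays by gamma per cycle

        return base_lr + (max_lr - base_lr) * scale_factor * scale_fn_val

    # One full 20-step cycle with step_size_up=10:
    for step in (0, 5, 10, 15, 20):
        print(step, round(cyclic_lr_value(step, 0.001, 0.1, step_size_up=10), 4))
    # 0 0.001, 5 0.0505, 10 0.1, 15 0.0505, 20 0.001

The printed values match the docstring example: step 5 of a 20-step cycle sits halfway up the rising edge, hence 0.0505.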

@@ -446,6 +446,23 @@ class TestSchedulers(mlx_tests.MLXTestCase):
         lr = lr_schedule(20)
         self.assertEqual(lr, expected_end_lr)
 
+    def test_cyclic_lr(self):
+        lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=10)
+
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 0.001, delta=1e-7)
+
+        lr = lr_schedule(10)
+        self.assertAlmostEqual(lr, 0.1, delta=1e-7)
+
+        lr = lr_schedule(20)
+        self.assertAlmostEqual(lr, 0.001, delta=1e-7)
+
+        lr_schedule = opt.cyclic_lr(0.001, 0.1, step_size_up=5, mode="triangular2")
+        lr = lr_schedule(15)
+        expected_lr = 0.001 + (0.1 - 0.001) * 0.5
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-6)
+
     def test_schedule_joiner(self):
         boundaries = [2, 3, 4]
         schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
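
The triangular2 assertion in the test reads most naturally as a worked calculation. A short sketch (plain Python, variable names chosen to mirror the implementation) showing why expected_lr comes out to 0.001 + (0.1 - 0.001) * 0.5:

    # Step 15 with step_size_up=5 (step_size_down defaults to 5):
    total_size = 5 + 5                     # full cycle length: 10 steps
    cycle = 1 + 15 // total_size           # = 2: step 15 is in the second cycle
    x = 1.0 + 15 / total_size - cycle      # = 0.5: exactly at the cycle's peak
    scale_factor = x / 0.5                 # = 1.0 (x == step_ratio, top of the rise)
    scale_fn_val = 1 / 2.0 ** (cycle - 1)  # = 0.5: triangular2 halves the peak each cycle
    lr = 0.001 + (0.1 - 0.001) * scale_factor * scale_fn_val
    print(lr)  # 0.0505, i.e. 0.001 + (0.1 - 0.001) * 0.5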