Mirror of https://github.com/ml-explore/mlx.git (synced 2025-07-21 08:41:13 +08:00)
Add linear warmup and schedule joining for use with existing schedules (#721)
* Add linear warmup to schedules for use with existing schedules
* Changed parameters for simplicity of most common case (0 initial value)
* Added ScheduleJoiner and updated documentation
* ScheduleJoiner -> join_schedules (ala optax #)
* black compliance
* Different evaluation of schedules
* nits

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
  parent: e6418781ab
  commit: 3b661b7394
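The headline change: two new schedule builders, `linear_schedule` and `join_schedules`, let a warmup ramp be chained onto any existing schedule. A minimal usage sketch assembled from the docstring examples in this commit (it assumes an mlx build that includes this patch):

```python
import mlx.optimizers as optim

# Ramp the learning rate from 0 to 1e-1 over the first 10 steps, then
# hand off to a cosine decay that runs for up to 200 steps of its own.
warmup = optim.linear_schedule(0.0, 1e-1, steps=10)
cosine = optim.cosine_decay(1e-1, 200)
lr_schedule = optim.join_schedules([warmup, cosine], [10])

# Any optimizer that accepts a callable learning rate can use the result.
optimizer = optim.Adam(learning_rate=lr_schedule)
print(optimizer.learning_rate)  # array(0.0, dtype=float32)
```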
@@ -8,6 +8,8 @@ Schedulers
 .. autosummary::
    :toctree: _autosummary
 
-   step_decay
-   exponential_decay
    cosine_decay
+   exponential_decay
+   join_schedules
+   linear_schedule
+   step_decay
@@ -1,11 +1,12 @@
 # Copyright © 2023-2024 Apple Inc.
 
 import math
+from typing import Callable, List
 
 import mlx.core as mx
 
 
-def exponential_decay(init: float, decay_rate: float):
+def exponential_decay(init: float, decay_rate: float) -> Callable:
     r"""Make an exponential decay scheduler.
 
     Args:
@@ -30,7 +31,7 @@ def exponential_decay(init: float, decay_rate: float):
     return schedule
 
 
-def step_decay(init: float, decay_rate: float, step_size: int):
+def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
     r"""Make a step decay scheduler.
 
     Args:
@@ -57,7 +58,7 @@ def step_decay(init: float, decay_rate: float, step_size: int):
     return schedule
 
 
-def cosine_decay(init: float, decay_steps: int):
+def cosine_decay(init: float, decay_steps: int) -> Callable:
     r"""Make a cosine decay scheduler.
 
     Args:
@@ -84,3 +85,73 @@ def cosine_decay(init: float, decay_steps: int):
         return init * decay
 
     return scheduler
+
+
+def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
+    r"""Join multiple schedules to create a new schedule.
+
+    Args:
+        schedules (list(Callable)): A list of schedules. Schedule :math:`i+1`
+            receives a step count indicating the number of steps since
+            the :math:`i`-th boundary.
+        boundaries (list(int)): A list of integers of length ``len(schedules) - 1``
+            that indicates when to transition between schedules.
+
+    Example:
+        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> cosine = optim.cosine_decay(1e-1, 200)
+        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
+        >>> optimizer.learning_rate
+        array(0.0, dtype=float32)
+        >>> for _ in range(12): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.0999938, dtype=float32)
+    """
+    if len(schedules) == 0:
+        raise ValueError("Must provide at least 1 schedule to join.")
+
+    if len(schedules) != len(boundaries) + 1:
+        raise ValueError(
+            f"Received {len(boundaries)} boundaries but "
+            f"expected {len(schedules) - 1}."
+        )
+
+    def schedule(step):
+        output = schedules[0](step)
+        for boundary, schedule in zip(boundaries, schedules[1:]):
+            output = mx.where(step < boundary, output, schedule(step - boundary))
+        return output
+
+    return schedule
+
+
+def linear_schedule(init: float, end: float, steps: int) -> Callable:
+    r"""Make a linear scheduler.
+
+    Args:
+        init (float): Initial value.
+        end (float): Final value.
+        steps (int): Number of steps to apply the schedule over. The value is
+            ``end`` for any steps beyond ``steps``.
+
+    Example:
+
+        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> optimizer.learning_rate
+        array(0.0, dtype=float32)
+        >>> for _ in range(101): optimizer.update({}, {})
+        ...
+        >>> optimizer.learning_rate
+        array(0.1, dtype=float32)
+    """
+    if steps < 1:
+        raise ValueError(f"steps must be greater than 0, but got {steps}.")
+
+    def step_fn(step):
+        step = mx.minimum(step, steps)
+        return step * ((end - init) / steps) + init
+
+    return step_fn
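Two behaviors of `join_schedules` are worth calling out: the hand-off is computed with `mx.where` rather than a Python `if`, so it avoids branching on array values, and every schedule after the first is evaluated at `step - boundary`, i.e. it restarts its own step count at the boundary. A small sketch of those semantics, assuming this patch is applied (the constant schedules mirror the unit test below):

```python
import mlx.optimizers as optim

# Two constant schedules joined at step 4: the first applies before the
# boundary, the second from the boundary onward.
joined = optim.join_schedules([lambda _: 3, lambda _: 4], [4])
print(joined(3).item())  # 3
print(joined(4).item())  # 4

# Schedules after the first see steps counted from their boundary, so a
# ramp placed after a boundary starts again from its own step 0.
ramp = optim.linear_schedule(0.0, 1.0, steps=2)
shifted = optim.join_schedules([lambda _: 0.0, ramp], [5])
print(shifted(5).item())  # 0.0: the ramp is evaluated at step 5 - 5 = 0
print(shifted(7).item())  # 1.0: evaluated at step 2, the end of the ramp
```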
@@ -328,6 +328,37 @@ class TestSchedulers(unittest.TestCase):
         expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
         self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
 
+    def test_schedule_joiner(self):
+        boundaries = [2, 3, 4]
+        schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
+        with self.assertRaises(ValueError):
+            opt.schedulers.join_schedules(schedules, boundaries)
+        boundaries = [2, 4]
+        schedule = opt.schedulers.join_schedules(schedules, boundaries)
+        self.assertEqual(schedule(0).item(), 3)
+        self.assertEqual(schedule(1).item(), 3)
+        self.assertEqual(schedule(2).item(), 4)
+        self.assertEqual(schedule(3).item(), 4)
+        self.assertEqual(schedule(5).item(), 5)
+        self.assertEqual(schedule(7).item(), 5)
+
+    def test_linear_warmup_with_cosine_decay(self):
+        warmup_schedule = opt.schedulers.linear_schedule(0.0, 1e-5, 100)
+        cosine_schedule = opt.schedulers.cosine_decay(1e-5, 100)
+        cos_with_warmup = opt.schedulers.join_schedules(
+            [warmup_schedule, cosine_schedule], [101]
+        )
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
     def test_compile_with_schedule(self):
         lr_schedule = opt.exponential_decay(1e-1, 0.9)
         optimizer = opt.SGD(learning_rate=lr_schedule)
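One detail the warmup test leans on: `linear_schedule` clamps the step at `steps` via `mx.minimum`, so the schedule holds at `end` once the ramp finishes rather than overshooting. A quick check, again assuming this patch is applied (printed values are float32, so approximate):

```python
import mlx.optimizers as optim

sched = optim.linear_schedule(0.0, 0.1, steps=100)
print(sched(0).item())    # 0.0, the initial value
print(sched(50).item())   # ~0.05, halfway up the ramp
print(sched(100).item())  # ~0.1, the final value
print(sched(500).item())  # still ~0.1: steps past `steps` are clamped
```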