mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-30 06:31:21 +08:00
Add minimum for cosine decay function (#859)
* Add minimum for cosine decay function * Update python/mlx/optimizers/schedulers.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
a54f06b16f
commit
f5a1582fe8
@ -58,13 +58,14 @@ def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
|
|||||||
return schedule
|
return schedule
|
||||||
|
|
||||||
|
|
||||||
def cosine_decay(init: float, decay_steps: int) -> Callable:
|
def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callable:
|
||||||
r"""Make a cosine decay scheduler.
|
r"""Make a cosine decay scheduler.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
init (float): Initial value.
|
init (float): Initial value.
|
||||||
decay_steps (int): Number of steps to decay over. The decayed
|
decay_steps (int): Number of steps to decay over. The decayed
|
||||||
value is constant for steps beyond ``decay_steps``.
|
value is constant for steps beyond ``decay_steps``.
|
||||||
|
minimum (float, optional): Minimal value to decay to. Default: ``0``.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
@ -82,7 +83,7 @@ def cosine_decay(init: float, decay_steps: int) -> Callable:
|
|||||||
def scheduler(step):
|
def scheduler(step):
|
||||||
s = mx.minimum(step, decay_steps)
|
s = mx.minimum(step, decay_steps)
|
||||||
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
||||||
return init * decay
|
return mx.maximum(init * decay, minimum)
|
||||||
|
|
||||||
return scheduler
|
return scheduler
|
||||||
|
|
||||||
|
@ -328,6 +328,11 @@ class TestSchedulers(unittest.TestCase):
|
|||||||
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
|
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
|
||||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||||
|
|
||||||
|
lr_schedule = opt.cosine_decay(0.1, 10, 0.05)
|
||||||
|
lr = lr_schedule(20)
|
||||||
|
expected_lr = 0.05
|
||||||
|
self.assertEqual(lr, expected_lr)
|
||||||
|
|
||||||
def test_schedule_joiner(self):
|
def test_schedule_joiner(self):
|
||||||
boundaries = [2, 3, 4]
|
boundaries = [2, 3, 4]
|
||||||
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
schedules = [lambda _: 3, lambda _: 4, lambda _: 5]
|
||||||
|
Loading…
Reference in New Issue
Block a user