mirror of
https://github.com/ml-explore/mlx.git
synced 2025-07-17 06:21:12 +08:00
Treate 'minimum' differently in cosine decay (#1138)
This commit is contained in:
parent
0a558577bf
commit
7e5674d8be
@ -58,14 +58,14 @@ def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
|
||||
return schedule
|
||||
|
||||
|
||||
def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callable:
|
||||
def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
|
||||
r"""Make a cosine decay scheduler.
|
||||
|
||||
Args:
|
||||
init (float): Initial value.
|
||||
decay_steps (int): Number of steps to decay over. The decayed
|
||||
value is constant for steps beyond ``decay_steps``.
|
||||
minimum (float, optional): Minimal value to decay to. Default: ``0``.
|
||||
end (float, optional): Final value to decay to. Default: ``0``.
|
||||
|
||||
Example:
|
||||
|
||||
@ -83,7 +83,7 @@ def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callabl
|
||||
def scheduler(step):
|
||||
s = mx.minimum(step, decay_steps)
|
||||
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
||||
return mx.maximum(init * decay, minimum)
|
||||
return end + decay * (init - end)
|
||||
|
||||
return scheduler
|
||||
|
||||
|
@ -329,9 +329,11 @@ class TestSchedulers(unittest.TestCase):
|
||||
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
||||
|
||||
lr_schedule = opt.cosine_decay(0.1, 10, 0.05)
|
||||
lr = lr_schedule(9)
|
||||
expected_end_lr = 0.05
|
||||
self.assertGreater(lr, expected_end_lr)
|
||||
lr = lr_schedule(20)
|
||||
expected_lr = 0.05
|
||||
self.assertEqual(lr, expected_lr)
|
||||
self.assertEqual(lr, expected_end_lr)
|
||||
|
||||
def test_schedule_joiner(self):
|
||||
boundaries = [2, 3, 4]
|
||||
|
Loading…
Reference in New Issue
Block a user