Add minimum for cosine decay function (#859)

* Add minimum for cosine decay function

* Update python/mlx/optimizers/schedulers.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Chime Ogbuji
2024-03-21 10:33:29 -04:00
committed by GitHub
parent a54f06b16f
commit f5a1582fe8
2 changed files with 8 additions and 2 deletions

View File

@@ -58,13 +58,14 @@ def step_decay(init: float, decay_rate: float, step_size: int) -> Callable:
return schedule
def cosine_decay(init: float, decay_steps: int) -> Callable:
def cosine_decay(init: float, decay_steps: int, minimum: float = 0.0) -> Callable:
r"""Make a cosine decay scheduler.
Args:
init (float): Initial value.
decay_steps (int): Number of steps to decay over. The decayed
value is constant for steps beyond ``decay_steps``.
minimum (float, optional): Minimal value to decay to. Default: ``0``.
Example:
@@ -82,7 +83,7 @@ def cosine_decay(init: float, decay_steps: int) -> Callable:
def scheduler(step):
s = mx.minimum(step, decay_steps)
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
return init * decay
return mx.maximum(init * decay, minimum)
return scheduler