diff --git a/python/mlx/optimizers/schedulers.py b/python/mlx/optimizers/schedulers.py
index a8c0354f3..67e4e29cd 100644
--- a/python/mlx/optimizers/schedulers.py
+++ b/python/mlx/optimizers/schedulers.py
@@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
         array(0.0999961, dtype=float32)
     """

-    def scheduler(step):
+    def schedule(step):
         s = mx.minimum(step, decay_steps)
         decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
         return end + decay * (init - end)

-    return scheduler
+    return schedule


 def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
@@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
             that indicates when to transition between schedules.

     Example:
-        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> linear = optim.linear_schedule(0, 1e-1, steps=10)
         >>> cosine = optim.cosine_decay(1e-1, 200)
-        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> lr_schedule = optim.join_schedules([linear, cosine], [10])
         >>> optimizer = optim.Adam(learning_rate=lr_schedule)
         >>> optimizer.learning_rate
         array(0.0, dtype=float32)
@@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:

     Example:

-        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
-        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
         >>> optimizer.learning_rate
         array(0.0, dtype=float32)
         >>> for _ in range(101): optimizer.update({}, {})
@@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
     if steps < 1:
         raise ValueError(f"steps must be greater than 0, but got {steps}.")

-    def step_fn(step):
+    def schedule(step):
         step = mx.minimum(step, steps)
         return step * ((end - init) / steps) + init

-    return step_fn
+    return schedule
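
For context, a minimal end-to-end sketch of the renamed schedule closures, assembled from the doctest examples in the diff above (the step counts and values mirror those doctests; nothing here is new behavior):

    import mlx.optimizers as optim

    # Warm up linearly from 0 to 1e-1 over 10 steps, then cosine-decay
    # from 1e-1 over 200 steps; switch schedules at step 10.
    linear = optim.linear_schedule(0, 1e-1, steps=10)
    cosine = optim.cosine_decay(1e-1, 200)
    lr_schedule = optim.join_schedules([linear, cosine], [10])

    optimizer = optim.Adam(learning_rate=lr_schedule)
    for _ in range(101):
        optimizer.update({}, {})
    # Learning rate has warmed up and is partway through the cosine decay.
    print(optimizer.learning_rate)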