refactor: all use schedule (#1973)

Chunyang Wen 2025-03-20 02:24:04 +08:00 committed by GitHub
parent 0a9777aa5c
commit 3779150750


@@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
         array(0.0999961, dtype=float32)
     """
 
-    def scheduler(step):
+    def schedule(step):
         s = mx.minimum(step, decay_steps)
         decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
         return end + decay * (init - end)
 
-    return scheduler
+    return schedule
 
 
 def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
@@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
         that indicates when to transition between schedules.
 
     Example:
-        >>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
+        >>> linear = optim.linear_schedule(0, 1e-1, steps=10)
         >>> cosine = optim.cosine_decay(1e-1, 200)
-        >>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
+        >>> lr_schedule = optim.join_schedules([linear, cosine], [10])
         >>> optimizer = optim.Adam(learning_rate=lr_schedule)
         >>> optimizer.learning_rate
         array(0.0, dtype=float32)
@@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
 
     Example:
 
-        >>> warmup = optim.linear_schedule(0, 1e-1, 100)
-        >>> optimizer = optim.Adam(learning_rate=warmup)
+        >>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
+        >>> optimizer = optim.Adam(learning_rate=lr_schedule)
         >>> optimizer.learning_rate
         array(0.0, dtype=float32)
         >>> for _ in range(101): optimizer.update({}, {})
@@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
     if steps < 1:
         raise ValueError(f"steps must be greater than 0, but got {steps}.")
 
-    def step_fn(step):
+    def schedule(step):
        step = mx.minimum(step, steps)
        return step * ((end - init) / steps) + init
 
-    return step_fn
+    return schedule
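
For reference, a minimal sketch of the renamed API in use, assembled from the docstring examples in this diff; it assumes mlx is installed and that optim refers to mlx.optimizers, as in the examples above:

    import mlx.optimizers as optim

    # Linear warmup for 10 steps, then cosine decay over 200 steps;
    # join_schedules switches schedules at the boundary step 10.
    linear = optim.linear_schedule(0, 1e-1, steps=10)
    cosine = optim.cosine_decay(1e-1, 200)
    lr_schedule = optim.join_schedules([linear, cosine], [10])

    optimizer = optim.Adam(learning_rate=lr_schedule)
    print(optimizer.learning_rate)  # array(0.0, dtype=float32)

    # Each call to optimizer.update() advances the schedule by one step.
    for _ in range(12):
        optimizer.update({}, {})
    print(optimizer.learning_rate)  # roughly 1e-1, just past the warmup boundary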