mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 09:21:16 +08:00
refactor: all use schedule (#1973)
This commit is contained in:
parent
0a9777aa5c
commit
3779150750
@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
|
||||
array(0.0999961, dtype=float32)
|
||||
"""
|
||||
|
||||
def scheduler(step):
|
||||
def schedule(step):
|
||||
s = mx.minimum(step, decay_steps)
|
||||
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
||||
return end + decay * (init - end)
|
||||
|
||||
return scheduler
|
||||
return schedule
|
||||
|
||||
|
||||
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
|
||||
@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
|
||||
that indicates when to transition between schedules.
|
||||
|
||||
Example:
|
||||
>>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
|
||||
>>> linear = optim.linear_schedule(0, 1e-1, steps=10)
|
||||
>>> cosine = optim.cosine_decay(1e-1, 200)
|
||||
>>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
|
||||
>>> lr_schedule = optim.join_schedules([linear, cosine], [10])
|
||||
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||
>>> optimizer.learning_rate
|
||||
array(0.0, dtype=float32)
|
||||
@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
|
||||
|
||||
Example:
|
||||
|
||||
>>> warmup = optim.linear_schedule(0, 1e-1, 100)
|
||||
>>> optimizer = optim.Adam(learning_rate=warmup)
|
||||
>>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
|
||||
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||
>>> optimizer.learning_rate
|
||||
array(0.0, dtype=float32)
|
||||
>>> for _ in range(101): optimizer.update({}, {})
|
||||
@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
|
||||
if steps < 1:
|
||||
raise ValueError(f"steps must be greater than 0, but got {steps}.")
|
||||
|
||||
def step_fn(step):
|
||||
def schedule(step):
|
||||
step = mx.minimum(step, steps)
|
||||
return step * ((end - init) / steps) + init
|
||||
|
||||
return step_fn
|
||||
return schedule
|
||||
|
Loading…
Reference in New Issue
Block a user