mirror of
https://github.com/ml-explore/mlx.git
synced 2025-06-24 17:31:16 +08:00
refactor: all use schedule (#1973)
This commit is contained in:
parent
0a9777aa5c
commit
3779150750
@ -80,12 +80,12 @@ def cosine_decay(init: float, decay_steps: int, end: float = 0.0) -> Callable:
|
|||||||
array(0.0999961, dtype=float32)
|
array(0.0999961, dtype=float32)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def scheduler(step):
|
def schedule(step):
|
||||||
s = mx.minimum(step, decay_steps)
|
s = mx.minimum(step, decay_steps)
|
||||||
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
decay = 0.5 * (1.0 + mx.cos((math.pi / decay_steps) * s))
|
||||||
return end + decay * (init - end)
|
return end + decay * (init - end)
|
||||||
|
|
||||||
return scheduler
|
return schedule
|
||||||
|
|
||||||
|
|
||||||
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
|
def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable:
|
||||||
@ -99,9 +99,9 @@ def join_schedules(schedules: List[Callable], boundaries: List[int]) -> Callable
|
|||||||
that indicates when to transition between schedules.
|
that indicates when to transition between schedules.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
>>> warmup = optim.linear_schedule(0, 1e-1, steps=10)
|
>>> linear = optim.linear_schedule(0, 1e-1, steps=10)
|
||||||
>>> cosine = optim.cosine_decay(1e-1, 200)
|
>>> cosine = optim.cosine_decay(1e-1, 200)
|
||||||
>>> lr_schedule = optim.join_schedules([warmup, cosine], [10])
|
>>> lr_schedule = optim.join_schedules([linear, cosine], [10])
|
||||||
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
|
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||||
>>> optimizer.learning_rate
|
>>> optimizer.learning_rate
|
||||||
array(0.0, dtype=float32)
|
array(0.0, dtype=float32)
|
||||||
@ -139,8 +139,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
|
|||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
>>> warmup = optim.linear_schedule(0, 1e-1, 100)
|
>>> lr_schedule = optim.linear_schedule(0, 1e-1, 100)
|
||||||
>>> optimizer = optim.Adam(learning_rate=warmup)
|
>>> optimizer = optim.Adam(learning_rate=lr_schedule)
|
||||||
>>> optimizer.learning_rate
|
>>> optimizer.learning_rate
|
||||||
array(0.0, dtype=float32)
|
array(0.0, dtype=float32)
|
||||||
>>> for _ in range(101): optimizer.update({}, {})
|
>>> for _ in range(101): optimizer.update({}, {})
|
||||||
@ -151,8 +151,8 @@ def linear_schedule(init: float, end: float, steps: int) -> Callable:
|
|||||||
if steps < 1:
|
if steps < 1:
|
||||||
raise ValueError(f"steps must be greater than 0, but got {steps}.")
|
raise ValueError(f"steps must be greater than 0, but got {steps}.")
|
||||||
|
|
||||||
def step_fn(step):
|
def schedule(step):
|
||||||
step = mx.minimum(step, steps)
|
step = mx.minimum(step, steps)
|
||||||
return step * ((end - init) / steps) + init
|
return step * ((end - init) / steps) + init
|
||||||
|
|
||||||
return step_fn
|
return schedule
|
||||||
|
Loading…
Reference in New Issue
Block a user