mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Configurable LR schedulers (#604)
* Initial config handler and test * Added means to run from CLI * Update lora config loading and tests * Constrain scheduler config (warmup and minimum LR) for each kind * Update reference to moved schedule_config module * Minor fix * Fix typos * Moved build_schedule and tests * nits in schedule config * flake * fix path --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -3,11 +3,32 @@ from typing import Dict
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import mlx.optimizers as opt
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .lora import LoRALinear
|
||||
|
||||
|
||||
def build_schedule(schedule_config: Dict):
    """
    Build a learning rate schedule from the given config.

    The config must provide a ``name`` (an attribute of
    ``mlx.optimizers.schedulers``) and positional ``arguments`` whose first
    element is the initial learning rate.  An optional ``warmup`` step count
    prepends a linear ramp from ``warmup_init`` (default 0.0) up to that
    initial rate before handing over to the main schedule.
    """
    factory = getattr(opt.schedulers, schedule_config["name"])
    args = schedule_config["arguments"]
    main_schedule = factory(*args)

    warmup_steps = schedule_config.get("warmup", 0)
    if not warmup_steps:
        return main_schedule

    ramp = opt.schedulers.linear_schedule(
        schedule_config.get("warmup_init", 0.0), args[0], warmup_steps
    )
    # Hand over from the warmup ramp to the main schedule one step past warmup.
    return opt.schedulers.join_schedules([ramp, main_schedule], [warmup_steps + 1])
|
||||
|
||||
|
||||
def linear_to_lora_layers(
|
||||
model: nn.Module,
|
||||
num_lora_layers: int,
|
||||
|
Reference in New Issue
Block a user