Configurable LR schedulers (#604)

* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* nits in schedule config

* flake

* fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2024-03-29 16:41:10 -04:00
Committed by: GitHub
Parent: b80adbcc3e
Commit: f6283ef7ce

7 changed files with 93 additions and 12 deletions


@@ -3,11 +3,32 @@ from typing import Dict
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_unflatten

from .lora import LoRALinear


def build_schedule(schedule_config: Dict):
    """
    Build a learning rate schedule from the given config.
    """
    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
    arguments = schedule_config["arguments"]
    initial_lr = arguments[0]
    bound_schedule_fn = schedule_fn(*arguments)
    if warmup_steps := schedule_config.get("warmup", 0):
        warmup_init = schedule_config.get("warmup_init", 0.0)
        # Ramp linearly from warmup_init to the schedule's initial LR,
        # then hand off to the configured schedule once warmup ends.
        warmup_fn = opt.schedulers.linear_schedule(
            warmup_init, initial_lr, warmup_steps
        )
        return opt.schedulers.join_schedules(
            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
        )
    else:
        return bound_schedule_fn


def linear_to_lora_layers(
    model: nn.Module,
    num_lora_layers: int,
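
For context, here is a minimal sketch of how build_schedule might be used end to end. The cosine_decay schedule, linear warmup handling, and callable learning rates are part of mlx.optimizers; the specific config values below are illustrative assumptions, not values from this commit.

import mlx.optimizers as opt

# Hypothetical config for illustration; the keys match what
# build_schedule reads ("name", "arguments", "warmup", "warmup_init").
schedule_config = {
    "name": "cosine_decay",
    "arguments": [1e-5, 1000],  # positional args: initial LR, decay steps
    "warmup": 100,              # optional: linear warmup steps
    "warmup_init": 1e-7,        # optional: LR at the start of warmup
}

lr_schedule = build_schedule(schedule_config)

# MLX optimizers accept a callable schedule in place of a fixed LR.
optimizer = opt.Adam(learning_rate=lr_schedule)

Because build_schedule returns a step-indexed callable, the warmup phase and the decay schedule are joined transparently: the optimizer simply queries the schedule at each step.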