Configurable LR schedulers (#604)

* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* Nits in schedule config

* Flake fixes

* Fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2024-03-29 16:41:10 -04:00
Committed by: GitHub
Parent: b80adbcc3e
Commit: f6283ef7ce
7 changed files with 93 additions and 12 deletions


@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load
 
 yaml_loader = yaml.SafeLoader
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
"test": False,
"test_batches": 500,
"max_seq_length": 2048,
"lr_schedule": None,
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}
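
With this default in place, a user's config can supply a schedule description that run() hands to build_schedule. As a rough illustration only (the key names below are assumptions for this sketch, not taken from the diff; the warmup and minimum-LR fields echo the constraints mentioned in the commit messages above), the loaded args.lr_schedule might look like:

# Hypothetical shape of args.lr_schedule after the YAML config is loaded.
# Key names are illustrative, not the verbatim config schema.
lr_schedule = {
    "name": "cosine_decay",     # which MLX schedule to build
    "arguments": [1e-5, 1000],  # positional args, e.g. initial LR and decay steps
    "warmup": 100,              # number of linear warmup steps
    "warmup_init": 1e-7,        # learning rate at the start of warmup
}
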
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
     )
     model.train()
-    opt = optim.Adam(learning_rate=args.learning_rate)
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule)
+            if args.lr_schedule
+            else args.learning_rate
+        )
+    )
 
     # Train model
     train(
         model=model,
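
For context, here is a minimal sketch of what a build_schedule helper along these lines could look like, built on MLX's schedule utilities (optim.linear_schedule and optim.join_schedules plus the named decay schedules). It assumes the hypothetical config keys shown above and is not the verbatim implementation from tuner/utils.py:

# Sketch only: build a learning-rate schedule from a config dict by name,
# optionally prepending a linear warmup phase.
from typing import Callable, Dict

import mlx.optimizers as optim


def build_schedule(schedule_config: Dict) -> Callable:
    # Look up the named schedule, e.g. optim.cosine_decay, and bind its args.
    schedule_fn = getattr(optim, schedule_config["name"])
    arguments = schedule_config["arguments"]
    main_schedule = schedule_fn(*arguments)

    warmup_steps = schedule_config.get("warmup", 0)
    if warmup_steps > 0:
        # Ramp linearly from warmup_init to the schedule's initial LR, then
        # hand off to the main schedule at the warmup boundary.
        warmup_init = schedule_config.get("warmup_init", 0.0)
        warmup = optim.linear_schedule(warmup_init, arguments[0], warmup_steps)
        return optim.join_schedules([warmup, main_schedule], [warmup_steps])
    return main_schedule

Because MLX optimizers accept either a float or a callable for learning_rate, the optim.Adam(...) change in the hunk above works unchanged whether a schedule or a plain learning rate is configured.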