Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 04:14:38 +08:00)
Configurable LR schedulers (#604)
* Initial config handler and test
* Added means to run from CLI
* Update lora config loading and tests
* Constrain scheduler config (warmup and minimum LR) for each kind
* Update reference to moved schedule_config module
* Minor fix
* Fix typos
* Moved build_schedule and tests
* nits in schedule config
* flake
* fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load
 
 yaml_loader = yaml.SafeLoader
 
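The newly imported `build_schedule` lives in `tuner/utils.py` (per the import path above); its implementation is not part of this diff. As a rough orientation, a helper like this could be sketched as follows — the config keys ("name", "arguments", "warmup", "warmup_init") are assumptions based on the commit message's mention of warmup and minimum LR, not confirmed by the diff:

# Sketch only; the real implementation is in tuner/utils.py.
import mlx.optimizers as optim

def build_schedule(schedule_config: dict):
    # Look up an MLX schedule factory (e.g. cosine_decay) by name.
    schedule_fn = getattr(optim, schedule_config["name"])
    arguments = schedule_config["arguments"]
    schedule = schedule_fn(*arguments)
    warmup_steps = schedule_config.get("warmup", 0)
    if warmup_steps:
        # Ramp linearly from warmup_init to the schedule's initial LR,
        # then hand off to the main schedule.
        warmup = optim.linear_schedule(
            schedule_config.get("warmup_init", 0.0), arguments[0], warmup_steps
        )
        return optim.join_schedules([warmup, schedule], [warmup_steps])
    return schedule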
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
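The default of `None` means no schedule is applied unless the user's YAML config (loaded via `yaml_loader` above) provides one. After parsing, such an entry arrives as a plain dict; a hypothetical example matching the sketch above — key names and values are illustrative, not mandated by this diff:

# Hypothetical parsed config entry; keys and values are illustrative only.
lr_schedule = {
    "name": "cosine_decay",           # an mlx.optimizers schedule factory
    "warmup": 100,                    # linear warmup steps (0 disables warmup)
    "warmup_init": 1e-7,              # LR at the start of warmup
    "arguments": [1e-5, 1000, 1e-7],  # init LR, decay steps, minimum LR
}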
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
     )
 
     model.train()
-    opt = optim.Adam(learning_rate=args.learning_rate)
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule)
+            if args.lr_schedule
+            else args.learning_rate
+        )
+    )
     # Train model
     train(
         model=model,
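This conditional works because MLX optimizers accept either a fixed float or a callable mapping the update step to a learning rate, which is what `build_schedule` returns. For instance:

import mlx.optimizers as optim

# A fixed learning rate...
opt = optim.Adam(learning_rate=1e-5)

# ...or a schedule: a callable the optimizer evaluates at each update step.
opt = optim.Adam(learning_rate=optim.cosine_decay(1e-5, 1000))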