Configurable LR schedulers (#604)

* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* nits in schedule config

* flake

* fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji
Date: 2024-03-29 16:41:10 -04:00
Committer: GitHub
Parent: b80adbcc3e
Commit: f6283ef7ce
7 changed files with 93 additions and 12 deletions

View File

@@ -16,7 +16,7 @@ lora_layers: 16
 batch_size: 4
 
 # Iterations to train for.
-iters: 100
+iters: 1000
 
 # Number of validation batches, -1 uses the entire validation set.
 val_batches: 25
@@ -43,7 +43,7 @@ save_every: 100
 test: false
 
 # Number of test set batches, -1 uses the entire test set.
-test_batches: 500
+test_batches: 100
 
 # Maximum sequence length.
 max_seq_length: 2048
@@ -60,3 +60,10 @@ lora_parameters:
   alpha: 16.0
   scale: 10.0
   dropout: 0.0
+
+# Schedule can only be specified in a config file, uncomment to use.
+#lr_schedule:
+#  name: cosine_decay
+#  warmup: 100 # 0 for no warmup
+#  warmup_init: 1e-7 # 0 if not specified
+#  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
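For orientation (not part of the diff), the commented example above denotes a 100-step linear warmup from 1e-7 up to the initial rate, followed by a cosine decay from 1e-5 down to 1e-7 over 1000 steps. A minimal sketch of the equivalent calls against the mlx.optimizers schedulers used elsewhere in this commit:

    import mlx.optimizers as opt

    # 100-step linear warmup from warmup_init (1e-7) to the initial LR (1e-5).
    warmup = opt.schedulers.linear_schedule(1e-7, 1e-5, 100)

    # Cosine decay from 1e-5 down to 1e-7 over 1000 steps (the "arguments" list).
    decay = opt.schedulers.cosine_decay(1e-5, 1000, 1e-7)

    # Hand off from the warmup to the decay once the warmup steps are done.
    lr_schedule = opt.schedulers.join_schedules([warmup, decay], [101])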

View File

@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load
 
 yaml_loader = yaml.SafeLoader
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
     )
     model.train()
-    opt = optim.Adam(learning_rate=args.learning_rate)
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule)
+            if args.lr_schedule
+            else args.learning_rate
+        )
+    )
 
     # Train model
     train(
         model=model,
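The branch above works because mlx optimizers accept either a fixed float or a callable schedule as learning_rate, so the same Adam construction covers both the scheduled and the constant-rate case. A minimal illustration (not part of the diff):

    import mlx.optimizers as optim

    # Constant rate: pass a float, exactly as before this change.
    opt_fixed = optim.Adam(learning_rate=1e-5)

    # Scheduled rate: pass a callable that maps the step count to a rate.
    opt_scheduled = optim.Adam(learning_rate=optim.schedulers.cosine_decay(1e-5, 1000))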

View File

@@ -1,6 +1,5 @@
 from dataclasses import dataclass
 from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn

View File

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from sys import exit
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 import mlx.core as mx
 import mlx.nn as nn

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 import mlx.core as mx
 import mlx.nn as nn
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
             mask = mask.astype(self.embed_tokens.weight.dtype)
 
         if cache is None:
-            past_key_values_length = 0
             cache = [None for _ in range(len(self.layers.layers))]
-        else:
-            if cache[0] is not None:
-                past_key_values_length = cache[0][0].shape[2]
 
         for e, layer in enumerate(self.layers.layers):
             h, c = layer(h, mask, cache[e])

View File

@@ -3,11 +3,32 @@ from typing import Dict
 
 import mlx.core as mx
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
 from .lora import LoRALinear
 
 
+def build_schedule(schedule_config: Dict):
+    """
+    Build a learning rate schedule from the given config.
+    """
+    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
+    arguments = schedule_config["arguments"]
+    initial_lr = arguments[0]
+    bound_schedule_fn = schedule_fn(*arguments)
+    if warmup_steps := schedule_config.get("warmup", 0):
+        warmup_init = schedule_config.get("warmup_init", 0.0)
+        warmup_fn = opt.schedulers.linear_schedule(
+            warmup_init, initial_lr, warmup_steps
+        )
+        return opt.schedulers.join_schedules(
+            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
+        )
+    else:
+        return bound_schedule_fn
+
+
 def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
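Putting the pieces together, a short usage sketch for the new helper (not part of the diff; the import path is inferred from the package-relative imports shown above):

    import mlx.optimizers as opt

    from mlx_lm.tuner.utils import build_schedule  # path inferred from the diff

    lr_schedule = build_schedule(
        {
            "name": "cosine_decay",           # any scheduler name under opt.schedulers
            "warmup": 100,                    # 0 skips the warmup branch entirely
            "warmup_init": 1e-7,              # defaults to 0.0 when omitted
            "arguments": [1e-5, 1000, 1e-7],  # forwarded unchanged to the scheduler
        }
    )

    # build_schedule returns a step -> learning-rate callable, so it can be
    # handed to the optimizer exactly as lora.py does above.
    optimizer = opt.Adam(learning_rate=lr_schedule)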