Configurable LR schedulers (#604)

* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* Nits in schedule config

* Flake8 fixes

* Fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Chime Ogbuji, 2024-03-29 16:41:10 -04:00 (committed by GitHub)
Parent: b80adbcc3e
Commit: f6283ef7ce
7 changed files with 93 additions and 12 deletions

View File

@@ -16,7 +16,7 @@ lora_layers: 16
 batch_size: 4

 # Iterations to train for.
-iters: 100
+iters: 1000

 # Number of validation batches, -1 uses the entire validation set.
 val_batches: 25
@@ -43,7 +43,7 @@ save_every: 100
 test: false

 # Number of test set batches, -1 uses the entire test set.
-test_batches: 500
+test_batches: 100

 # Maximum sequence length.
 max_seq_length: 2048
@@ -60,3 +60,10 @@ lora_parameters:
   alpha: 16.0
   scale: 10.0
   dropout: 0.0
+
+# Schedule can only be specified in a config file, uncomment to use.
+#lr_schedule:
+#  name: cosine_decay
+#  warmup: 100 # 0 for no warmup
+#  warmup_init: 1e-7 # 0 if not specified
+#  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
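Read back into Python, the commented block above corresponds to the dictionary below. This is only an illustration using the example's values; per build_schedule, name and arguments are required, while warmup and warmup_init are optional.

# Sketch only: the lr_schedule block above, as the dict handed to build_schedule.
lr_schedule = {
    "name": "cosine_decay",           # looked up in mlx.optimizers.schedulers
    "warmup": 100,                    # warmup steps; omit or set to 0 to disable
    "warmup_init": 1e-7,              # starting LR for the warmup phase (default 0)
    "arguments": [1e-5, 1000, 1e-7],  # positional arguments passed to cosine_decay
}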

View File

@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load

 yaml_loader = yaml.SafeLoader
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
     )
     model.train()
-    opt = optim.Adam(learning_rate=args.learning_rate)
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule)
+            if args.lr_schedule
+            else args.learning_rate
+        )
+    )

     # Train model
     train(
         model=model,
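Pulled out of the hunk above into a self-contained sketch; the schedule dict here is just the example config, not something lora.py hard-codes. MLX optimizers accept either a float or a schedule callable, so the same Adam constructor covers both paths.

import mlx.optimizers as optim

from mlx_lm.tuner.utils import build_schedule

# Stand-ins for args.lr_schedule and args.learning_rate from the parsed config.
lr_schedule = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 1000, 1e-7]}
learning_rate = 1e-5

opt = optim.Adam(
    learning_rate=build_schedule(lr_schedule) if lr_schedule else learning_rate
)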

View File

@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

View File

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from sys import exit
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple

 import mlx.core as mx
 import mlx.nn as nn

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union

 import mlx.core as mx
 import mlx.nn as nn
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
         mask = mask.astype(self.embed_tokens.weight.dtype)

         if cache is None:
-            past_key_values_length = 0
             cache = [None for _ in range(len(self.layers.layers))]
-        else:
-            if cache[0] is not None:
-                past_key_values_length = cache[0][0].shape[2]

         for e, layer in enumerate(self.layers.layers):
             h, c = layer(h, mask, cache[e])

View File

@@ -3,11 +3,32 @@ from typing import Dict

 import mlx.core as mx
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_unflatten

 from .lora import LoRALinear


+def build_schedule(schedule_config: Dict):
+    """
+    Build a learning rate schedule from the given config.
+    """
+    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
+    arguments = schedule_config["arguments"]
+    initial_lr = arguments[0]
+    bound_schedule_fn = schedule_fn(*arguments)
+    if warmup_steps := schedule_config.get("warmup", 0):
+        warmup_init = schedule_config.get("warmup_init", 0.0)
+        warmup_fn = opt.schedulers.linear_schedule(
+            warmup_init, initial_lr, warmup_steps
+        )
+        return opt.schedulers.join_schedules(
+            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
+        )
+    else:
+        return bound_schedule_fn
+
+
 def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
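A quick illustration of the warmup join (the step values are arbitrary): the first warmup + 1 steps come from the linear warmup, after which the named schedule takes over, offset by the boundary.

from mlx_lm.tuner.utils import build_schedule

schedule = build_schedule(
    {"name": "cosine_decay", "warmup": 10, "warmup_init": 1e-7, "arguments": [1e-5, 100]}
)
print(schedule(0))   # ~1e-7: warmup starts at warmup_init
print(schedule(10))  # ~1e-5: warmup has ramped up to the schedule's initial LR
print(schedule(60))  # partway through the cosine decay toward 0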

View File

@@ -1,14 +1,18 @@
 # Copyright © 2024 Apple Inc.

+import math
 import sys
 import unittest
 from io import StringIO
 from unittest.mock import MagicMock

 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
+from mlx_lm.lora import yaml_loader
 from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.utils import build_schedule


 class TestLora(unittest.TestCase):
@@ -120,5 +124,52 @@ class TestLora(unittest.TestCase):
         self.assertEqual(self.capturedOutput.getvalue(), expected_output)


+class TestScheduleConfig(unittest.TestCase):
+    def test_join(self):
+        config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
+        cos_with_warmup = build_schedule(config)
+        self.assertIsNotNone(cos_with_warmup)
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
+    def test_single_schedule(self):
+        config = {
+            "name": "cosine_decay",
+            "arguments": [0.1, 10],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_non_zero_warmup(self):
+        config = {
+            "name": "cosine_decay",
+            "warmup": 10,
+            "warmup_init": 1e-6,
+            "arguments": [1e-5, 20],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 1e-6, delta=1e-7)
+
+    def test_malformed_config(self):
+        config = {"warmup": 100}
+        self.assertRaises(KeyError, build_schedule, config)
+
+        config = {"cosine_decay": None}
+        self.assertRaises(KeyError, build_schedule, config)
+
+
 if __name__ == "__main__":
     unittest.main()
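The join test above leans on one MLX behaviour worth calling out: with a schedule as learning_rate, every optimizer.update() advances the step counter and re-evaluates the schedule, even when the parameter and gradient trees are empty. A standalone sketch of that pattern:

import mlx.optimizers as opt

from mlx_lm.tuner.utils import build_schedule

schedule = build_schedule({"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]})
optimizer = opt.Adam(learning_rate=schedule)
for _ in range(100):
    optimizer.update({}, {})  # no real parameters; each call advances the step
print(optimizer.learning_rate.item())  # ~1e-5 once the warmup has completed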