Configurable LR schedulers (#604)

* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* Nits in schedule config

* Fix flake errors

* Fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Authored by Chime Ogbuji on 2024-03-29 16:41:10 -04:00, committed by GitHub
parent b80adbcc3e
commit f6283ef7ce
7 changed files with 93 additions and 12 deletions

View File

@@ -16,7 +16,7 @@ lora_layers: 16
 batch_size: 4
 
 # Iterations to train for.
-iters: 100
+iters: 1000
 
 # Number of validation batches, -1 uses the entire validation set.
 val_batches: 25
@@ -43,7 +43,7 @@ save_every: 100
 test: false
 
 # Number of test set batches, -1 uses the entire test set.
-test_batches: 500
+test_batches: 100
 
 # Maximum sequence length.
 max_seq_length: 2048
@@ -60,3 +60,10 @@ lora_parameters:
   alpha: 16.0
   scale: 10.0
   dropout: 0.0
+
+# Schedule can only be specified in a config file, uncomment to use.
+#lr_schedule:
+#  name: cosine_decay
+#  warmup: 100 # 0 for no warmup
+#  warmup_init: 1e-7 # 0 if not specified
+#  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
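For orientation, here is a rough Python equivalent of the commented-out schedule above, written against the mlx.optimizers schedulers that the new build_schedule helper (further down in this diff) calls into. This is a sketch for illustration, not code from the commit:

import mlx.optimizers as opt

# warmup: ramp from warmup_init (1e-7) to the schedule's first argument (1e-5) over 100 steps
warmup_fn = opt.schedulers.linear_schedule(1e-7, 1e-5, 100)
# "arguments" are passed straight through to the named scheduler
cosine_fn = opt.schedulers.cosine_decay(1e-5, 1000, 1e-7)
# the cosine decay takes over once the warmup steps are done
lr_schedule = opt.schedulers.join_schedules([warmup_fn, cosine_fn], [101])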

View File

@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load
 
 
 yaml_loader = yaml.SafeLoader
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
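The new "lr_schedule" default is None, so a schedule only takes effect when a YAML config provides one, matching the "config file only" comment in the example config. A minimal sketch of that fallback behaviour follows; the dict names are abridged stand-ins, not the actual merge code in lora.py:

# Abridged stand-in for the CONFIG_DEFAULTS dict above (sketch only).
CONFIG_DEFAULTS = {"learning_rate": 1e-5, "lr_schedule": None}
# Stand-in for a parsed YAML config that opts into a schedule.
user_config = {"lr_schedule": {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 1000, 1e-7]}}

merged = {**CONFIG_DEFAULTS, **user_config}  # config values override the defaults
# merged["lr_schedule"] is now the dict handed to build_schedule();
# if the config omits it, the fixed merged["learning_rate"] is used instead.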
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
         )
         model.train()
-        opt = optim.Adam(learning_rate=args.learning_rate)
+        opt = optim.Adam(
+            learning_rate=(
+                build_schedule(args.lr_schedule)
+                if args.lr_schedule
+                else args.learning_rate
+            )
+        )
 
         # Train model
         train(
             model=model,
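For context: mlx.optimizers optimizers accept either a plain float or a callable schedule as learning_rate, which is what lets the branch above pass build_schedule(...) or args.learning_rate interchangeably. A small illustration, sketch only:

import mlx.optimizers as optim

fixed = optim.Adam(learning_rate=1e-5)  # constant learning rate
scheduled = optim.Adam(learning_rate=optim.schedulers.cosine_decay(1e-5, 1000, 1e-7))  # advances one schedule step per update()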

View File

@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn

View File

@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from sys import exit
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn

View File

@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
         mask = mask.astype(self.embed_tokens.weight.dtype)
 
         if cache is None:
-            past_key_values_length = 0
             cache = [None for _ in range(len(self.layers.layers))]
-        else:
-            if cache[0] is not None:
-                past_key_values_length = cache[0][0].shape[2]
 
         for e, layer in enumerate(self.layers.layers):
             h, c = layer(h, mask, cache[e])

View File

@@ -3,11 +3,32 @@ from typing import Dict
 import mlx.core as mx
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
 from .lora import LoRALinear
 
 
+def build_schedule(schedule_config: Dict):
+    """
+    Build a learning rate schedule from the given config.
+    """
+    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
+    arguments = schedule_config["arguments"]
+    initial_lr = arguments[0]
+    bound_schedule_fn = schedule_fn(*arguments)
+    if warmup_steps := schedule_config.get("warmup", 0):
+        warmup_init = schedule_config.get("warmup_init", 0.0)
+        warmup_fn = opt.schedulers.linear_schedule(
+            warmup_init, initial_lr, warmup_steps
+        )
+        return opt.schedulers.join_schedules(
+            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
+        )
+    else:
+        return bound_schedule_fn
+
+
 def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
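A short usage sketch for the new helper; the import path is taken from the test file below, and the config dict mirrors the commented example in the YAML config:

import mlx.optimizers as opt
from mlx_lm.tuner.utils import build_schedule

config = {"name": "cosine_decay", "warmup": 100, "warmup_init": 1e-7, "arguments": [1e-5, 1000, 1e-7]}
lr_schedule = build_schedule(config)

optimizer = opt.Adam(learning_rate=lr_schedule)
optimizer.update({}, {})  # each update advances the schedule by one step
print(optimizer.learning_rate.item())  # still near warmup_init, early on the warmup ramp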

View File

@@ -1,14 +1,18 @@
 # Copyright © 2024 Apple Inc.
 
+import math
 import sys
 import unittest
 from io import StringIO
 from unittest.mock import MagicMock
 
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 
 from mlx_lm import lora, tuner
+from mlx_lm.lora import yaml_loader
 from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.utils import build_schedule
 
 
 class TestLora(unittest.TestCase):
@@ -120,5 +124,52 @@ class TestLora(unittest.TestCase):
         self.assertEqual(self.capturedOutput.getvalue(), expected_output)
 
 
+class TestScheduleConfig(unittest.TestCase):
+    def test_join(self):
+        config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
+        cos_with_warmup = build_schedule(config)
+        self.assertIsNotNone(cos_with_warmup)
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
+    def test_single_schedule(self):
+        config = {
+            "name": "cosine_decay",
+            "arguments": [0.1, 10],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_non_zero_warmup(self):
+        config = {
+            "name": "cosine_decay",
+            "warmup": 10,
+            "warmup_init": 1e-6,
+            "arguments": [1e-5, 20],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 1e-6, delta=1e-7)
+
+    def test_malformed_config(self):
+        config = {"warmup": 100}
+        self.assertRaises(KeyError, build_schedule, config)
+
+        config = {"cosine_decay": None}
+        self.assertRaises(KeyError, build_schedule, config)
+
+
 if __name__ == "__main__":
     unittest.main()
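One subtlety the tests above exercise: build_schedule joins the warmup and the named scheduler at warmup_steps + 1, so steps 0 through warmup follow the linear ramp and the named scheduler then counts its own steps from the boundary. A small sketch with tiny numbers, not part of the commit, assuming the mlx.optimizers scheduler semantics used throughout this diff:

import mlx.optimizers as opt

warmup = opt.schedulers.linear_schedule(0.0, 1e-5, 10)  # 0 -> 1e-5 over 10 steps
decay = opt.schedulers.cosine_decay(1e-5, 20)  # then cosine decay over 20 steps
schedule = opt.schedulers.join_schedules([warmup, decay], [11])

print(schedule(0).item())   # 0.0, start of the warmup ramp
print(schedule(10).item())  # ~1e-5, end of warmup
print(schedule(11).item())  # ~1e-5, the cosine branch starts from its initial value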