Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Configurable LR schedulers (#604)
* Initial config handler and test
* Added means to run from CLI
* Update lora config loading and tests
* Constrain scheduler config (warmup and minimum LR) for each kind
* Update reference to moved schedule_config module
* Minor fix
* Fix typos
* Moved build_schedule and tests
* nits in schedule config
* flake
* fix path

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent b80adbcc3e
commit f6283ef7ce
@@ -16,7 +16,7 @@ lora_layers: 16
 batch_size: 4
 
 # Iterations to train for.
-iters: 100
+iters: 1000
 
 # Number of validation batches, -1 uses the entire validation set.
 val_batches: 25
@@ -43,7 +43,7 @@ save_every: 100
 test: false
 
 # Number of test set batches, -1 uses the entire test set.
-test_batches: 500
+test_batches: 100
 
 # Maximum sequence length.
 max_seq_length: 2048
@@ -60,3 +60,10 @@ lora_parameters:
   alpha: 16.0
   scale: 10.0
   dropout: 0.0
+
+# Schedule can only be specified in a config file, uncomment to use.
+#lr_schedule:
+#  name: cosine_decay
+#  warmup: 100 # 0 for no warmup
+#  warmup_init: 1e-7 # 0 if not specified
+#  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
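For reference, a minimal sketch of what the commented `lr_schedule` block resolves to once uncommented, assuming mlx and this repo's `mlx_lm` package are installed: the YAML maps to a plain dict that `build_schedule` (added below in `mlx_lm.tuner.utils`) turns into a schedule callable.

from mlx_lm.tuner.utils import build_schedule  # added by this commit

# The commented YAML block above, expressed as the dict the config loader produces.
lr_schedule_config = {
    "name": "cosine_decay",           # any schedule in mlx.optimizers.schedulers
    "warmup": 100,                    # 0 for no warmup
    "warmup_init": 1e-7,              # 0 if not specified
    "arguments": [1e-5, 1000, 1e-7],  # positional args for cosine_decay: init, decay_steps, end
}

schedule = build_schedule(lr_schedule_config)
print(schedule(0))    # ~warmup_init at step 0
print(schedule(101))  # ~1e-5 (the peak LR) right after the 100-step warmup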
@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load
 
 yaml_loader = yaml.SafeLoader
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
     )
 
     model.train()
-    opt = optim.Adam(learning_rate=args.learning_rate)
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule)
+            if args.lr_schedule
+            else args.learning_rate
+        )
+    )
     # Train model
     train(
         model=model,
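The `optim.Adam` change above leans on the fact that MLX optimizers accept either a fixed float or a schedule callable for `learning_rate`. A minimal sketch of that pattern, with illustrative values standing in for `args.lr_schedule` and `args.learning_rate`:

import mlx.optimizers as optim
from mlx_lm.tuner.utils import build_schedule

# Illustrative values; in lora.py these come from args.lr_schedule / args.learning_rate.
lr_schedule = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 1000]}
learning_rate = 1e-5

opt = optim.Adam(
    learning_rate=(build_schedule(lr_schedule) if lr_schedule else learning_rate)
)

# With a schedule, the optimizer advances the learning rate on each update,
# so opt.learning_rate changes over the course of training.
opt.update({}, {})
print(opt.learning_rate.item())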
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from sys import exit
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
             mask = mask.astype(self.embed_tokens.weight.dtype)
 
         if cache is None:
-            past_key_values_length = 0
             cache = [None for _ in range(len(self.layers.layers))]
-        else:
-            if cache[0] is not None:
-                past_key_values_length = cache[0][0].shape[2]
 
         for e, layer in enumerate(self.layers.layers):
             h, c = layer(h, mask, cache[e])
@@ -3,11 +3,32 @@ from typing import Dict
 
 import mlx.core as mx
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
 from .lora import LoRALinear
 
 
+def build_schedule(schedule_config: Dict):
+    """
+    Build a learning rate schedule from the given config.
+    """
+    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
+    arguments = schedule_config["arguments"]
+    initial_lr = arguments[0]
+    bound_schedule_fn = schedule_fn(*arguments)
+    if warmup_steps := schedule_config.get("warmup", 0):
+        warmup_init = schedule_config.get("warmup_init", 0.0)
+        warmup_fn = opt.schedulers.linear_schedule(
+            warmup_init, initial_lr, warmup_steps
+        )
+        return opt.schedulers.join_schedules(
+            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
+        )
+    else:
+        return bound_schedule_fn
+
+
 def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
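`build_schedule` looks up the named schedule in `mlx.optimizers.schedulers` and, when `warmup` is non-zero, joins a linear warmup to it at `warmup_steps + 1`. A sketch of the equivalent hand-written composition, using illustrative values:

import mlx.optimizers as opt
from mlx_lm.tuner.utils import build_schedule

warmup_then_cosine = build_schedule(
    {"name": "cosine_decay", "warmup": 10, "warmup_init": 1e-6, "arguments": [1e-5, 20]}
)

# The same composition written out by hand with the mlx scheduler helpers:
warmup = opt.schedulers.linear_schedule(1e-6, 1e-5, 10)
cosine = opt.schedulers.cosine_decay(1e-5, 20)
by_hand = opt.schedulers.join_schedules([warmup, cosine], [11])

for step in (0, 5, 11, 20):
    assert abs(warmup_then_cosine(step).item() - by_hand(step).item()) < 1e-12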
@@ -1,14 +1,18 @@
 # Copyright © 2024 Apple Inc.
 
+import math
 import sys
 import unittest
 from io import StringIO
 from unittest.mock import MagicMock
 
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
+from mlx_lm.lora import yaml_loader
 from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.utils import build_schedule
 
 
 class TestLora(unittest.TestCase):
@@ -120,5 +124,52 @@ class TestLora(unittest.TestCase):
         self.assertEqual(self.capturedOutput.getvalue(), expected_output)
 
 
+class TestScheduleConfig(unittest.TestCase):
+    def test_join(self):
+        config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
+        cos_with_warmup = build_schedule(config)
+        self.assertIsNotNone(cos_with_warmup)
+
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
+    def test_single_schedule(self):
+
+        config = {
+            "name": "cosine_decay",
+            "arguments": [0.1, 10],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_non_zero_warmup(self):
+        config = {
+            "name": "cosine_decay",
+            "warmup": 10,
+            "warmup_init": 1e-6,
+            "arguments": [1e-5, 20],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 1e-6, delta=1e-7)
+
+    def test_malformed_config(self):
+        config = {"warmup": 100}
+        self.assertRaises(KeyError, build_schedule, config)
+
+        config = {"cosine_decay": None}
+        self.assertRaises(KeyError, build_schedule, config)
+
+
 if __name__ == "__main__":
     unittest.main()
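The `expected_lr` expressions in these tests follow the standard cosine decay formula. As a quick sanity check outside the unittest harness, assuming the same mlx scheduler the tests exercise:

import math

import mlx.optimizers as opt

# cosine_decay(init, decay_steps): lr(step) = init * 0.5 * (1 + cos(pi * step / decay_steps))
# for step <= decay_steps; test_single_schedule asserts exactly this at step 4.
schedule = opt.schedulers.cosine_decay(0.1, 10)
expected = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
assert abs(schedule(4).item() - expected) < 1e-7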