From f6283ef7ce05b358c1f770567f5927450a7ba66c Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Fri, 29 Mar 2024 16:41:10 -0400
Subject: [PATCH] Configurable LR schedulers (#604)

* Initial config handler and test

* Added means to run from CLI

* Update lora config loading and tests

* Constrain scheduler config (warmup and minimum LR) for each kind

* Update reference to moved schedule_config module

* Minor fix

* Fix typos

* Moved build_schedule and tests

* nits in schedule config

* flake

* fix path

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/examples/lora_config.yaml | 11 ++++--
 llms/mlx_lm/lora.py                   | 11 ++++--
 llms/mlx_lm/models/gemma.py           |  3 +-
 llms/mlx_lm/models/olmo.py            |  2 +-
 llms/mlx_lm/models/plamo.py           |  6 +---
 llms/mlx_lm/tuner/utils.py            | 21 +++++++++++
 llms/tests/test_lora.py               | 51 +++++++++++++++++++++++++++
 7 files changed, 93 insertions(+), 12 deletions(-)

diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index 1585d69e..90bdd6ad 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -16,7 +16,7 @@ lora_layers: 16
 batch_size: 4
 
 # Iterations to train for.
-iters: 100
+iters: 1000
 
 # Number of validation batches, -1 uses the entire validation set.
 val_batches: 25
@@ -43,7 +43,7 @@ save_every: 100
 test: false
 
 # Number of test set batches, -1 uses the entire test set.
-test_batches: 500
+test_batches: 100
 
 # Maximum sequence length.
 max_seq_length: 2048
@@ -60,3 +60,10 @@ lora_parameters:
   alpha: 16.0
   scale: 10.0
   dropout: 0.0
+
+# Schedule can only be specified in a config file, uncomment to use.
+#lr_schedule:
+#  name: cosine_decay
+#  warmup: 100 # 0 for no warmup
+#  warmup_init: 1e-7 # 0 if not specified
+#  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index adc426e4..9e94868e 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -15,7 +15,7 @@ from mlx.utils import tree_flatten
 
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.utils import linear_to_lora_layers
+from .tuner.utils import build_schedule, linear_to_lora_layers
 from .utils import load
 
 yaml_loader = yaml.SafeLoader
@@ -53,6 +53,7 @@ CONFIG_DEFAULTS = {
     "test": False,
     "test_batches": 500,
     "max_seq_length": 2048,
+    "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
 }
 
@@ -199,7 +200,13 @@ def run(args, training_callback: TrainingCallback = None):
         )
 
         model.train()
-        opt = optim.Adam(learning_rate=args.learning_rate)
+        opt = optim.Adam(
+            learning_rate=(
+                build_schedule(args.lr_schedule)
+                if args.lr_schedule
+                else args.learning_rate
+            )
+        )
         # Train model
         train(
             model=model,
diff --git a/llms/mlx_lm/models/gemma.py b/llms/mlx_lm/models/gemma.py
index 0ab99e58..fa6cab9e 100644
--- a/llms/mlx_lm/models/gemma.py
+++ b/llms/mlx_lm/models/gemma.py
@@ -1,6 +1,5 @@
 from dataclasses import dataclass
-from functools import partial
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
diff --git a/llms/mlx_lm/models/olmo.py b/llms/mlx_lm/models/olmo.py
index b2ceec37..541735de 100644
--- a/llms/mlx_lm/models/olmo.py
+++ b/llms/mlx_lm/models/olmo.py
@@ -1,6 +1,6 @@
 from dataclasses import dataclass
 from sys import exit
-from typing import Dict, Optional, Tuple, Union
+from typing import Optional, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
diff --git a/llms/mlx_lm/models/plamo.py b/llms/mlx_lm/models/plamo.py
index c4a87a1e..53c1252c 100644
--- a/llms/mlx_lm/models/plamo.py
+++ b/llms/mlx_lm/models/plamo.py
@@ -1,5 +1,5 @@
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, List, Optional, Tuple, Union
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -183,11 +183,7 @@ class PlamoModel(nn.Module):
             mask = mask.astype(self.embed_tokens.weight.dtype)
 
         if cache is None:
-            past_key_values_length = 0
             cache = [None for _ in range(len(self.layers.layers))]
-        else:
-            if cache[0] is not None:
-                past_key_values_length = cache[0][0].shape[2]
 
         for e, layer in enumerate(self.layers.layers):
             h, c = layer(h, mask, cache[e])
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index e7113a36..91990d84 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -3,11 +3,32 @@ from typing import Dict
 
 import mlx.core as mx
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
 from .lora import LoRALinear
 
 
+def build_schedule(schedule_config: Dict):
+    """
+    Build a learning rate schedule from the given config.
+    """
+    schedule_fn = getattr(opt.schedulers, schedule_config["name"])
+    arguments = schedule_config["arguments"]
+    initial_lr = arguments[0]
+    bound_schedule_fn = schedule_fn(*arguments)
+    if warmup_steps := schedule_config.get("warmup", 0):
+        warmup_init = schedule_config.get("warmup_init", 0.0)
+        warmup_fn = opt.schedulers.linear_schedule(
+            warmup_init, initial_lr, warmup_steps
+        )
+        return opt.schedulers.join_schedules(
+            [warmup_fn, bound_schedule_fn], [warmup_steps + 1]
+        )
+    else:
+        return bound_schedule_fn
+
+
 def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
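
A minimal usage sketch of the build_schedule helper added above (not part of the patch itself), assuming the patch is applied and mlx is installed. The config dict mirrors the commented example in lora_config.yaml; cosine_decay, linear_schedule, and join_schedules all live in mlx.optimizers.

# Sketch: compose a linear warmup with a cosine decay via build_schedule and
# hand the resulting schedule to an optimizer, as lora.py now does when
# `lr_schedule` is present in the YAML config.
import mlx.optimizers as optim

from mlx_lm.tuner.utils import build_schedule

schedule_config = {
    "name": "cosine_decay",           # any callable in mlx.optimizers.schedulers
    "warmup": 100,                    # linear warmup steps; 0 disables warmup
    "warmup_init": 1e-7,              # learning rate at step 0
    "arguments": [1e-5, 1000, 1e-7],  # positional args passed to cosine_decay
}

lr_schedule = build_schedule(schedule_config)
optimizer = optim.Adam(learning_rate=lr_schedule)

# The schedule is a function of the step count: it ramps from warmup_init to
# 1e-5 over the first 100 steps, then follows the cosine decay down to 1e-7.
print(lr_schedule(0), lr_schedule(100), lr_schedule(600))
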
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
index f7666a42..ccd43fb9 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_lora.py
@@ -1,14 +1,18 @@
 # Copyright © 2024 Apple Inc.
+import math
 import sys
 import unittest
 from io import StringIO
 from unittest.mock import MagicMock
 
 import mlx.nn as nn
+import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 
 from mlx_lm import lora, tuner
+from mlx_lm.lora import yaml_loader
 from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.utils import build_schedule
 
 
 class TestLora(unittest.TestCase):
@@ -120,5 +124,52 @@ class TestLora(unittest.TestCase):
         self.assertEqual(self.capturedOutput.getvalue(), expected_output)
 
 
+class TestScheduleConfig(unittest.TestCase):
+    def test_join(self):
+        config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
+        cos_with_warmup = build_schedule(config)
+        self.assertIsNotNone(cos_with_warmup)
+
+        self.assertEqual(cos_with_warmup(0), 0.0)
+        self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
+        optimizer = opt.Adam(learning_rate=cos_with_warmup)
+        for _ in range(100):
+            optimizer.update({}, {})
+        self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
+        for _ in range(100):
+            optimizer.update({}, {})
+        expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
+        self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
+
+    def test_single_schedule(self):
+
+        config = {
+            "name": "cosine_decay",
+            "arguments": [0.1, 10],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(4)
+        expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
+        self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
+
+    def test_non_zero_warmup(self):
+        config = {
+            "name": "cosine_decay",
+            "warmup": 10,
+            "warmup_init": 1e-6,
+            "arguments": [1e-5, 20],
+        }
+        lr_schedule = build_schedule(config)
+        lr = lr_schedule(0)
+        self.assertAlmostEqual(lr, 1e-6, delta=1e-7)
+
+    def test_malformed_config(self):
+        config = {"warmup": 100}
+        self.assertRaises(KeyError, build_schedule, config)
+
+        config = {"cosine_decay": None}
+        self.assertRaises(KeyError, build_schedule, config)
+
+
 if __name__ == "__main__":
     unittest.main()
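
For completeness, a sketch (again, not part of the patch) of the full path from a YAML lr_schedule block to the optimizer, mirroring the run() change in lora.py above. It assumes the patch is applied, and it writes the floats as 1.0e-5 / 1.0e-7 because PyYAML's SafeLoader does not resolve a bare 1e-5 as a float.

# Sketch: parse an lr_schedule block like the one in lora_config.yaml and wire
# it into Adam, falling back to the flat learning rate when no schedule is set.
import yaml

import mlx.optimizers as optim

from mlx_lm.lora import yaml_loader
from mlx_lm.tuner.utils import build_schedule

config_text = """
learning_rate: 1.0e-5
lr_schedule:
  name: cosine_decay
  warmup: 100
  warmup_init: 1.0e-7
  arguments: [1.0e-5, 1000, 1.0e-7]
"""

config = yaml.load(config_text, yaml_loader)
lr_schedule = config["lr_schedule"]

# Same wiring as run() in lora.py: use the schedule if one is configured,
# otherwise pass the plain learning rate.
optimizer = optim.Adam(
    learning_rate=(
        build_schedule(lr_schedule) if lr_schedule else config["learning_rate"]
    )
)
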