From c457a3f88bc7ccfd227bf8c7808cca51dba3e518 Mon Sep 17 00:00:00 2001 From: madroid Date: Sun, 2 Jun 2024 21:38:42 +0800 Subject: [PATCH] LoRA: Extract small function (#614) * LoRA: Extract pre_processing_model function * LoRA: Extract small functions(train_model,evaluate_model) * move test case to test_tuner_utils.py * nits * nits * remove extra param, validate at it 0 * version * fix test --------- Co-authored-by: Awni Hannun --- llms/mlx_lm/examples/lora_config.yaml | 3 +- llms/mlx_lm/lora.py | 168 +++++++++++++------------- llms/mlx_lm/tuner/__init__.py | 2 + llms/mlx_lm/tuner/dora.py | 9 +- llms/mlx_lm/tuner/lora.py | 18 +-- llms/mlx_lm/tuner/trainer.py | 64 +++++----- llms/mlx_lm/tuner/utils.py | 24 +++- llms/mlx_lm/version.py | 2 +- llms/tests/test_lora.py | 63 ---------- llms/tests/test_tuner_utils.py | 85 +++++++++++++ 10 files changed, 232 insertions(+), 206 deletions(-) create mode 100644 llms/tests/test_tuner_utils.py diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index 188fd7b5..d3c0d22a 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -60,8 +60,7 @@ lora_parameters: # These will be applied for the last lora_layers keys: ["self_attn.q_proj", "self_attn.v_proj"] rank: 8 - alpha: 16.0 - scale: 10.0 + scale: 20.0 dropout: 0.0 # Schedule can only be specified in a config file, uncomment to use. diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 55abb338..15a3535e 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -10,11 +10,16 @@ import mlx.nn as nn import mlx.optimizers as optim import numpy as np import yaml -from mlx.utils import tree_flatten +from .tokenizer_utils import TokenizerWrapper from .tuner.datasets import load_dataset from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train -from .tuner.utils import apply_lora_layers, build_schedule, linear_to_lora_layers +from .tuner.utils import ( + apply_lora_layers, + build_schedule, + linear_to_lora_layers, + print_trainable_parameters, +) from .utils import load, save_config yaml_loader = yaml.SafeLoader @@ -33,7 +38,6 @@ yaml_loader.add_implicit_resolver( list("-+0123456789."), ) - CONFIG_DEFAULTS = { "model": "mlx_model", "train": False, @@ -150,111 +154,103 @@ def build_parser(): return parser -def print_trainable_parameters(model): - def nparams(m): - if isinstance(m, nn.QuantizedLinear): - return m.weight.size * (32 // m.bits) - return sum(v.size for _, v in tree_flatten(m.parameters())) +def train_model( + args, + model: nn.Module, + tokenizer: TokenizerWrapper, + train_set, + valid_set, + training_callback: TrainingCallback = None, +): + # Freeze all layers + model.freeze() - leaf_modules = tree_flatten( - model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + # Convert linear layers to lora layers and unfreeze in the process + linear_to_lora_layers(model, args.lora_layers, args.lora_parameters) + + # Resume training the given adapters. 
+    if args.resume_adapter_file is not None:
+        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
+        model.load_weights(args.resume_adapter_file, strict=False)
+
+    print_trainable_parameters(model)
+
+    adapter_path = Path(args.adapter_path)
+    adapter_path.mkdir(parents=True, exist_ok=True)
+    adapter_file = adapter_path / "adapters.safetensors"
+    save_config(vars(args), adapter_path / "adapter_config.json")
+
+    # init training args
+    training_args = TrainingArgs(
+        batch_size=args.batch_size,
+        iters=args.iters,
+        val_batches=args.val_batches,
+        steps_per_report=args.steps_per_report,
+        steps_per_eval=args.steps_per_eval,
+        steps_per_save=args.save_every,
+        adapter_file=adapter_file,
+        max_seq_length=args.max_seq_length,
+        grad_checkpoint=args.grad_checkpoint,
     )
-    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
-    trainable_p = (
-        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
+
+    model.train()
+    opt = optim.Adam(
+        learning_rate=(
+            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+        )
     )
-    print(
-        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
-        f"({trainable_p:.3f}M/{total_p:.3f}M)"
+    # Train model
+    train(
+        model=model,
+        tokenizer=tokenizer,
+        args=training_args,
+        optimizer=opt,
+        train_dataset=train_set,
+        val_dataset=valid_set,
+        training_callback=training_callback,
     )


+def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
+    model.eval()
+
+    test_loss = evaluate(
+        model=model,
+        dataset=test_set,
+        tokenizer=tokenizer,
+        batch_size=args.batch_size,
+        num_batches=args.test_batches,
+        max_seq_length=args.max_seq_length,
+    )
+
+    test_ppl = math.exp(test_loss)
+
+    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
+
+
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
-    # Freeze all layers
-    model.freeze()
-
-    adapter_path = Path(args.adapter_path)
-    adapter_file = adapter_path / "adapters.safetensors"
+    print("Loading datasets")
+    train_set, valid_set, test_set = load_dataset(args, tokenizer)

     if args.test and not args.train:
         # Allow testing without LoRA layers by providing empty path
         if args.adapter_path != "":
-            apply_lora_layers(model, adapter_path)
-    elif args.train:
-        adapter_path.mkdir(parents=True, exist_ok=True)
-        save_config(vars(args), adapter_path / "adapter_config.json")
+            apply_lora_layers(model, args.adapter_path)

-        # Convert linear layers to lora layers and unfreeze in the process
-        linear_to_lora_layers(
-            model, args.lora_layers, args.lora_parameters, args.use_dora
-        )
-        print_trainable_parameters(model)
+    elif args.train:
+        print("Training")
+        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
     else:
         raise ValueError("Must provide at least one of --train or --test")

-    print("Loading datasets")
-    train_set, valid_set, test_set = load_dataset(args, tokenizer)
-
-    # Resume training the given adapters.
- if args.resume_adapter_file is not None: - print(f"Loading pretrained adapters from {args.resume_adapter_file}") - model.load_weights(args.resume_adapter_file, strict=False) - - if args.train: - print("Training") - # init training args - training_args = TrainingArgs( - batch_size=args.batch_size, - iters=args.iters, - val_batches=args.val_batches, - steps_per_report=args.steps_per_report, - steps_per_eval=args.steps_per_eval, - steps_per_save=args.save_every, - adapter_file=adapter_file, - max_seq_length=args.max_seq_length, - grad_checkpoint=args.grad_checkpoint, - ) - - model.train() - opt = optim.Adam( - learning_rate=( - build_schedule(args.lr_schedule) - if args.lr_schedule - else args.learning_rate - ) - ) - # Train model - train( - model=model, - tokenizer=tokenizer, - args=training_args, - optimizer=opt, - train_dataset=train_set, - val_dataset=valid_set, - training_callback=training_callback, - ) - if args.test: print("Testing") - model.eval() - - test_loss = evaluate( - model=model, - dataset=test_set, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_batches=args.test_batches, - max_seq_length=args.max_seq_length, - ) - - test_ppl = math.exp(test_loss) - - print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") + evaluate_model(args, model, tokenizer, test_set) def main(): diff --git a/llms/mlx_lm/tuner/__init__.py b/llms/mlx_lm/tuner/__init__.py index e69de29b..2e6d2f90 100644 --- a/llms/mlx_lm/tuner/__init__.py +++ b/llms/mlx_lm/tuner/__init__.py @@ -0,0 +1,2 @@ +from .trainer import TrainingArgs, evaluate, train +from .utils import linear_to_lora_layers diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py index 364f5b45..de10556b 100644 --- a/llms/mlx_lm/tuner/dora.py +++ b/llms/mlx_lm/tuner/dora.py @@ -11,9 +11,8 @@ class DoRALinear(nn.Module): def from_linear( linear: nn.Linear, r: int = 8, - alpha: float = 16, dropout: float = 0.0, - scale: float = 10.0, + scale: float = 20.0, ): # TODO support quantized weights in DoRALinear output_dims, input_dims = linear.weight.shape @@ -23,7 +22,6 @@ class DoRALinear(nn.Module): input_dims=input_dims, output_dims=output_dims, r=r, - alpha=alpha, dropout=dropout, scale=scale, ) @@ -56,9 +54,8 @@ class DoRALinear(nn.Module): input_dims: int, output_dims: int, r: int = 8, - alpha: float = 16, dropout: float = 0.0, - scale: float = 10.0, + scale: float = 20.0, bias: bool = False, ): super().__init__() @@ -68,7 +65,7 @@ class DoRALinear(nn.Module): self.dropout = nn.Dropout(p=dropout) # Scale for low-rank update - self.scale = scale * (alpha / r) + self.scale = scale # Low rank lora weights scale = 1 / math.sqrt(input_dims) diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py index 22b0b4d0..19babb0e 100644 --- a/llms/mlx_lm/tuner/lora.py +++ b/llms/mlx_lm/tuner/lora.py @@ -13,9 +13,8 @@ class LoRALinear(nn.Module): def from_linear( linear: nn.Linear, r: int = 8, - alpha: float = 16, dropout: float = 0.0, - scale: float = 10.0, + scale: float = 20.0, ): # TODO remove when input_dims and output_dims are attributes # on linear and quantized linear @@ -26,7 +25,6 @@ class LoRALinear(nn.Module): input_dims=input_dims, output_dims=output_dims, r=r, - alpha=alpha, dropout=dropout, scale=scale, ) @@ -74,9 +72,8 @@ class LoRALinear(nn.Module): input_dims: int, output_dims: int, r: int = 8, - alpha: float = 16, dropout: float = 0.0, - scale: float = 10.0, + scale: float = 20.0, bias: bool = False, ): super().__init__() @@ -87,7 +84,7 @@ class LoRALinear(nn.Module): self.dropout = nn.Dropout(p=dropout) # 
Scale for low-rank update - self.scale = scale * (alpha / r) + self.scale = scale # Low rank lora weights scale = 1 / math.sqrt(input_dims) @@ -109,16 +106,14 @@ class LoRASwitchLinear(nn.Module): def from_linear( linear: nn.Module, r: int = 8, - alpha: float = 16, dropout: float = 0.0, - scale: float = 10.0, + scale: float = 20.0, ): lora_lin = LoRASwitchLinear( input_dims=linear.input_dims, output_dims=linear.output_dims, num_experts=linear.num_experts, r=r, - alpha=alpha, dropout=dropout, scale=scale, ) @@ -163,9 +158,8 @@ class LoRASwitchLinear(nn.Module): output_dims: int, num_experts: int, r: int = 8, - alpha: float = 16, dropout: float = 0.0, - scale: float = 10.0, + scale: float = 20.0, bias: bool = False, ): super().__init__() @@ -176,7 +170,7 @@ class LoRASwitchLinear(nn.Module): self.dropout = nn.Dropout(p=dropout) # Scale for low-rank update - self.scale = scale * (alpha / r) + self.scale = scale # Low rank lora weights scale = 1 / math.sqrt(input_dims) diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index f5957782..feecf523 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -29,9 +29,6 @@ def grad_checkpoint(layer): @dataclass class TrainingArgs: - lora_layers: int = field( - default=16, metadata={"help": "Number of layers to fine-tune"} - ) batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) iters: int = field(default=100, metadata={"help": "Iterations to train for."}) val_batches: int = field( @@ -211,6 +208,35 @@ def train( train=True, ), ): + # Report validation loss if needed, the first validation loss + # is always measured before any training. + if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: + stop = time.perf_counter() + val_loss = evaluate( + model=model, + dataset=val_dataset, + loss=loss, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.val_batches, + max_seq_length=args.max_seq_length, + iterate_batches=iterate_batches, + ) + val_time = time.perf_counter() - stop + print( + f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s" + ) + + if training_callback is not None: + val_info = { + "iteration": it, + "val_loss": val_loss, + "val_time": val_time, + } + training_callback.on_val_loss_report(val_info) + + start = time.perf_counter() + lvalue, toks = step(batch) mx.eval(state, lvalue, toks) @@ -220,9 +246,9 @@ def train( # Report training loss if needed if it % args.steps_per_report == 0 or it == args.iters: - train_loss = np.mean(losses) - stop = time.perf_counter() + + train_loss = np.mean(losses) learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) tokens_sec = float(n_tokens) / (stop - start) @@ -253,34 +279,6 @@ def train( n_tokens = 0 start = time.perf_counter() - # Report validation loss if needed - if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: - stop = time.perf_counter() - val_loss = evaluate( - model=model, - dataset=val_dataset, - loss=loss, - tokenizer=tokenizer, - batch_size=args.batch_size, - num_batches=args.val_batches, - max_seq_length=args.max_seq_length, - iterate_batches=iterate_batches, - ) - val_time = time.perf_counter() - stop - print( - f"Iter {it}: " f"Val loss {val_loss:.3f}, " f"Val took {val_time:.3f}s" - ) - - if training_callback is not None: - val_info = { - "iteration": it, - "val_loss": val_loss, - "val_time": val_time, - } - training_callback.on_val_loss_report(val_info) - - start = time.perf_counter() - # Save adapter weights if it % 
args.steps_per_save == 0: save_adapter(model, args.adapter_file) diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index d547e018..c0ef4b76 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -7,7 +7,7 @@ from typing import Dict import mlx.core as mx import mlx.nn as nn import mlx.optimizers as opt -from mlx.utils import tree_unflatten +from mlx.utils import tree_flatten, tree_unflatten from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear from .dora import DoRALinear @@ -48,7 +48,7 @@ def linear_to_lora_layers( num_lora_layers (int): The number of blocks to convert to lora layers starting from the last layer. config (dict): More configuration parameters for LoRA, including the - rank, alpha, scale, and optional layer keys. + rank, scale, and optional layer keys. use_dora (bool): If True, uses DoRA instead of LoRA. Default: ``False`` """ @@ -79,7 +79,6 @@ def linear_to_lora_layers( return LoRALayer.from_linear( layer, r=config["rank"], - alpha=config["alpha"], scale=config["scale"], dropout=config["dropout"], ) @@ -218,3 +217,22 @@ def remove_lora_layers(model: nn.Module) -> nn.Module: if len(reset_layers) > 0: model.update_modules(tree_unflatten(reset_layers)) return model + + +def print_trainable_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + trainable_p = ( + sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + ) + print( + f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% " + f"({trainable_p:.3f}M/{total_p:.3f}M)" + ) diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py index e97f7f0e..086e3505 100644 --- a/llms/mlx_lm/version.py +++ b/llms/mlx_lm/version.py @@ -1,3 +1,3 @@ # Copyright © 2023-2024 Apple Inc. 
-__version__ = "0.14.1" +__version__ = "0.14.2" diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py index 5918c634..1cddaf3a 100644 --- a/llms/tests/test_lora.py +++ b/llms/tests/test_lora.py @@ -6,7 +6,6 @@ import unittest from io import StringIO from unittest.mock import MagicMock -import mlx.nn as nn import mlx.optimizers as opt from mlx.utils import tree_flatten from mlx_lm import lora, tuner @@ -61,68 +60,6 @@ class TestLora(unittest.TestCase): params["keys"] = ["self_attn.k_proj"] check_config(params) - def test_quantized_print_trainable_parameters(self): - model = MagicMock() - quantized_linear = MagicMock(spec=nn.QuantizedLinear) - quantized_linear.weight = MagicMock(size=1e6) - quantized_linear.bits = 8 - lora_linear = MagicMock(spec=LoRALinear) - lora_linear.weight = MagicMock(size=2e6) - lora_linear.parameters.return_value = [lora_linear.weight] - - linear = MagicMock(spec=nn.Linear) - linear.weight = MagicMock(size=3e6) - linear.parameters.return_value = [linear.weight] - - model.leaf_modules.return_value = { - "quantized_linear": quantized_linear, - "lora_linear": lora_linear, - "linear": linear, - } - - model.trainable_parameters.return_value = { - "layer1.weight": MagicMock(size=1e6), - "layer3.weight": MagicMock(size=2e6), - } - expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n" - lora.print_trainable_parameters(model) - self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits) - self.capturedOutput.truncate(0) - self.capturedOutput.seek(0) - - quantized_linear.weight = MagicMock(size=1e6) - quantized_linear.bits = 4 - expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n" - lora.print_trainable_parameters(model) - self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits) - self.capturedOutput.truncate(0) - self.capturedOutput.seek(0) - - def test_print_trainable_parameters(self): - model = MagicMock() - linear1 = MagicMock(spec=nn.Linear) - linear1.weight = MagicMock(size=1e6) - linear1.parameters.return_value = [linear1.weight] - linear2 = MagicMock(spec=nn.Linear) - linear2.weight = MagicMock(size=2e6) - linear2.parameters.return_value = [linear2.weight] - lora_linear = MagicMock(spec=LoRALinear) - lora_linear.weight = MagicMock(size=3e6) - lora_linear.parameters.return_value = [lora_linear.weight] - model.leaf_modules.return_value = { - "linear1": linear1, - "linear2": linear2, - "lora_linear": lora_linear, - } - - model.trainable_parameters.return_value = { - "layer1.weight": MagicMock(size=1e6), - "layer3.weight": MagicMock(size=2e6), - } - expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n" - lora.print_trainable_parameters(model) - self.assertEqual(self.capturedOutput.getvalue(), expected_output) - class TestScheduleConfig(unittest.TestCase): def test_join(self): diff --git a/llms/tests/test_tuner_utils.py b/llms/tests/test_tuner_utils.py new file mode 100644 index 00000000..0256683c --- /dev/null +++ b/llms/tests/test_tuner_utils.py @@ -0,0 +1,85 @@ +# Copyright © 2024 Apple Inc. 
+ +import sys +import unittest +from io import StringIO +from unittest.mock import MagicMock + +import mlx.nn as nn +from mlx_lm.tuner.lora import LoRALinear +from mlx_lm.tuner.utils import print_trainable_parameters + + +class TestTunerUtils(unittest.TestCase): + def setUp(self): + self.capturedOutput = StringIO() + sys.stdout = self.capturedOutput + + def tearDown(self): + sys.stdout = sys.__stdout__ + + def test_quantized_print_trainable_parameters(self): + model = MagicMock() + quantized_linear = MagicMock(spec=nn.QuantizedLinear) + quantized_linear.weight = MagicMock(size=1e6) + quantized_linear.bits = 8 + lora_linear = MagicMock(spec=LoRALinear) + lora_linear.weight = MagicMock(size=2e6) + lora_linear.parameters.return_value = [lora_linear.weight] + + linear = MagicMock(spec=nn.Linear) + linear.weight = MagicMock(size=3e6) + linear.parameters.return_value = [linear.weight] + + model.leaf_modules.return_value = { + "quantized_linear": quantized_linear, + "lora_linear": lora_linear, + "linear": linear, + } + + model.trainable_parameters.return_value = { + "layer1.weight": MagicMock(size=1e6), + "layer3.weight": MagicMock(size=2e6), + } + expected_output_8bits = "Trainable parameters: 33.333% (3.000M/9.000M)\n" + print_trainable_parameters(model) + self.assertEqual(self.capturedOutput.getvalue(), expected_output_8bits) + self.capturedOutput.truncate(0) + self.capturedOutput.seek(0) + + quantized_linear.weight = MagicMock(size=1e6) + quantized_linear.bits = 4 + expected_output_4bits = "Trainable parameters: 23.077% (3.000M/13.000M)\n" + print_trainable_parameters(model) + self.assertEqual(self.capturedOutput.getvalue(), expected_output_4bits) + self.capturedOutput.truncate(0) + self.capturedOutput.seek(0) + + def test_print_trainable_parameters(self): + model = MagicMock() + linear1 = MagicMock(spec=nn.Linear) + linear1.weight = MagicMock(size=1e6) + linear1.parameters.return_value = [linear1.weight] + linear2 = MagicMock(spec=nn.Linear) + linear2.weight = MagicMock(size=2e6) + linear2.parameters.return_value = [linear2.weight] + lora_linear = MagicMock(spec=LoRALinear) + lora_linear.weight = MagicMock(size=3e6) + lora_linear.parameters.return_value = [lora_linear.weight] + model.leaf_modules.return_value = { + "linear1": linear1, + "linear2": linear2, + "lora_linear": lora_linear, + } + + model.trainable_parameters.return_value = { + "layer1.weight": MagicMock(size=1e6), + "layer3.weight": MagicMock(size=2e6), + } + expected_output = "Trainable parameters: 50.000% (3.000M/6.000M)\n" + print_trainable_parameters(model) + self.assertEqual(self.capturedOutput.getvalue(), expected_output) + + +if __name__ == "__main__": + unittest.main()
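
Note for reviewers: dropping alpha keeps the default behaviour, since the old effective scaling was scale * (alpha / r) = 10.0 * 16 / 8 = 20.0, which matches the new scale: 20.0 default in both the config example and the layer classes. For exercising the extracted train_model / evaluate_model helpers without the CLI, a minimal sketch follows (not part of the patch; it assumes CONFIG_DEFAULTS in lora.py carries every field that train_model, evaluate_model, and load_dataset read, and the data path below is a placeholder).

from types import SimpleNamespace

from mlx_lm.lora import CONFIG_DEFAULTS, evaluate_model, train_model
from mlx_lm.tuner.datasets import load_dataset
from mlx_lm.utils import load

# Start from the CLI defaults defined in lora.py and override what we need.
# Any field missing from CONFIG_DEFAULTS would have to be set here explicitly.
args = SimpleNamespace(**CONFIG_DEFAULTS)
args.model = "mlx_model"  # default path to a converted MLX model
args.data = "data"        # placeholder folder with train/valid/test .jsonl files
args.train = True
args.test = True

model, tokenizer = load(args.model)
train_set, valid_set, test_set = load_dataset(args, tokenizer)

# train_model freezes the base model, injects LoRA layers, and saves
# adapters.safetensors under args.adapter_path; evaluate_model reports
# test loss and perplexity.
train_model(args, model, tokenizer, train_set, valid_set)
evaluate_model(args, model, tokenizer, test_set)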