diff --git a/.circleci/config.yml b/.circleci/config.yml index 2ecda2dc..556f209e 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -61,5 +61,7 @@ workflows: type: approval - apple/authenticate: context: pr-approval + - mlx_lm_build_and_test: + requires: [ hold ] - linux_build_and_test: requires: [ hold ] diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml index dc324358..32099e0d 100644 --- a/llms/mlx_lm/examples/lora_config.yaml +++ b/llms/mlx_lm/examples/lora_config.yaml @@ -48,3 +48,12 @@ test_batches: 500 # Maximum sequence length. max_seq_length: 2048 + +# LoRA parameters can only be specified in a config file +lora_parameters: + # The layer keys to apply LoRA to. + # These will be applied for the last lora_layers + keys: ["self_attn.q_proj", "self_attn.v_proj"] + rank: 8 + alpha: 16.0 + scale: 10.0 diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 5f57eb04..e11fed84 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + import argparse import json import math @@ -49,6 +51,7 @@ CONFIG_DEFAULTS = { "test": False, "test_batches": 500, "max_seq_length": 2048, + "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, } @@ -58,7 +61,6 @@ def build_parser(): "--model", help="The path to the local model directory or Hugging Face repo.", ) - # Generation args parser.add_argument( "--max-tokens", "-m", @@ -196,7 +198,7 @@ def run(args, training_callback: TrainingCallback = None): # Freeze all layers model.freeze() # Convert linear layers to lora layers and unfreeze in the process - linear_to_lora_layers(model, args.lora_layers) + linear_to_lora_layers(model, args.lora_layers, args.lora_parameters) p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 print(f"Total parameters {p:.3f}M") diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py index d83b9025..2ad0656a 100644 --- a/llms/mlx_lm/tuner/lora.py +++ b/llms/mlx_lm/tuner/lora.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. 
+ import math import mlx.core as mx @@ -9,8 +11,8 @@ class LoRALinear(nn.Module): def from_linear( linear: nn.Linear, r: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.0, + alpha: float = 16, + dropout: float = 0.0, scale: float = 10.0, ): # TODO remove when input_dims and output_dims are attributes @@ -22,8 +24,8 @@ class LoRALinear(nn.Module): input_dims=input_dims, output_dims=output_dims, r=r, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, + alpha=alpha, + dropout=dropout, scale=scale, ) lora_lin.linear = linear @@ -70,8 +72,8 @@ class LoRALinear(nn.Module): input_dims: int, output_dims: int, r: int = 8, - lora_alpha: float = 16, - lora_dropout: float = 0.0, + alpha: float = 16, + dropout: float = 0.0, scale: float = 10.0, bias: bool = False, ): @@ -80,10 +82,10 @@ class LoRALinear(nn.Module): # Regular linear layer weights self.linear = nn.Linear(input_dims, output_dims, bias=bias) - self.lora_dropout = nn.Dropout(p=lora_dropout) + self.dropout = nn.Dropout(p=dropout) # Scale for low-rank update - self.scale = scale * (lora_alpha / r) + self.scale = scale * (alpha / r) # Low rank lora weights scale = 1 / math.sqrt(input_dims) @@ -99,5 +101,5 @@ class LoRALinear(nn.Module): if isinstance(self.linear, nn.QuantizedLinear): dtype = self.linear.scales.dtype y = self.linear(x.astype(dtype)) - z = (self.lora_dropout(x) @ self.lora_a) @ self.lora_b + z = (self.dropout(x) @ self.lora_a) @ self.lora_b return y + self.scale * z diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 43ab66a6..ec1f40a7 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -1,3 +1,5 @@ +# Copyright © 2024 Apple Inc. + import time from dataclasses import dataclass, field from pathlib import Path diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py index bfa5cdf9..355e1699 100644 --- a/llms/mlx_lm/tuner/utils.py +++ b/llms/mlx_lm/tuner/utils.py @@ -1,4 +1,5 @@ import os +from typing import Dict import mlx.core as mx import mlx.nn as nn @@ -7,7 +8,11 @@ from mlx.utils import tree_unflatten from .lora import LoRALinear -def linear_to_lora_layers(model: nn.Module, num_lora_layers: int): +def linear_to_lora_layers( + model: nn.Module, + num_lora_layers: int, + config: Dict, +): """ Convert some of the models linear layers to lora layers. @@ -15,16 +20,28 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int): model (nn.Module): The neural network model. num_lora_layers (int): The number of blocks to convert to lora layers starting from the last layer. + config (dict): More configuration parameters for LoRA, including the + rank, alpha, scale, and optional layer keys. """ - def check_lora_layers(num_model): - if num_lora_layers > num_model: - raise ValueError( - f"Requested {num_lora_layers} LoRA layers " - f"but the model only has {num_model} layers." - ) + num_layers = len(model.layers) + if num_lora_layers > num_layers: + raise ValueError( + f"Requested {num_lora_layers} LoRA layers " + f"but the model only has {num_layers} layers." 
+        )
 
-    if model.model_type in [
+    to_lora = lambda lin: LoRALinear.from_linear(
+        lin, r=config["rank"], alpha=config["alpha"], scale=config["scale"]
+    )
+
+    # If the lora_parameters are set, we assume the keys
+    # are correct for the given model
+
+    keys = config.get("keys", None)
+    if keys is not None:
+        keys = set(keys)
+    elif model.model_type in [
         "mistral",
         "llama",
         "phi",
@@ -34,32 +51,21 @@ def linear_to_lora_layers(model: nn.Module, num_lora_layers: int):
         "gemma",
         "starcoder2",
     ]:
-        check_lora_layers(len(model.model.layers))
-
-        for l in model.model.layers[len(model.model.layers) - num_lora_layers :]:
-            l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
-            l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
-            if hasattr(l, "block_sparse_moe"):
-                l.block_sparse_moe.gate = LoRALinear.from_linear(
-                    l.block_sparse_moe.gate
-                )
+        keys = set(["self_attn.q_proj", "self_attn.v_proj"])
+        if model.model_type == "mixtral":
+            keys.add("block_sparse_moe.gate")
     elif model.model_type == "olmo":
-        check_lora_layers(len(model.model.transformer.blocks))
-
-        for l in model.model.transformer.blocks[
-            len(model.model.transformer.blocks) - num_lora_layers :
-        ]:
-            l.att_proj = LoRALinear.from_linear(l.att_proj)
+        keys = set(["att_proj"])
     elif model.model_type == "phi-msft":
-        check_lora_layers(len(model.transformer.h))
-
-        for l in model.transformer.h[len(model.transformer.h) - num_lora_layers :]:
-            l.mixer.Wqkv = LoRALinear.from_linear(l.mixer.Wqkv)
-            l.moe.gate = LoRALinear.from_linear(l.moe.gate)
-
+        keys = set(["mixer.Wqkv", "moe.gate"])
     else:
         raise ValueError(f"Lora does not support {model.model_type}")
 
+    for l in model.layers[num_layers - num_lora_layers :]:
+        modules = l.named_modules()
+        lora_layers = [(k, to_lora(m)) for k, m in modules if k in keys]
+        l.update_modules(tree_unflatten(lora_layers))
+
 
 def apply_lora_layers(model: nn.Module, adapter_file: str) -> nn.Module:
     """
diff --git a/llms/tests/test_lora.py b/llms/tests/test_lora.py
new file mode 100644
index 00000000..ef3ea78e
--- /dev/null
+++ b/llms/tests/test_lora.py
@@ -0,0 +1,52 @@
+# Copyright © 2024 Apple Inc.
+
+import unittest
+
+import mlx.core as mx
+from mlx.utils import tree_flatten
+from mlx_lm import tuner, utils
+
+
+class TestLora(unittest.TestCase):
+
+    def test_to_lora(self):
+        from mlx_lm.models import llama
+
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=1024,
+            num_hidden_layers=4,
+            intermediate_size=2048,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+
+        lora_layers = 4
+
+        def check_config(params):
+            n_keys = 2
+            if "keys" in params:
+                n_keys = len(params["keys"])
+            model = llama.Model(args)
+            model.freeze()
+            tuner.utils.linear_to_lora_layers(model, lora_layers, params)
+            trainable_params = sum(
+                v.size for _, v in tree_flatten(model.trainable_parameters())
+            )
+            self.assertEqual(
+                trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
+            )
+
+        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
+        check_config(params)
+
+        params["rank"] = 1
+        check_config(params)
+
+        params["keys"] = ["self_attn.k_proj"]
+        check_config(params)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/llms/tests/test_utils.py b/llms/tests/test_utils.py
index f68bb13c..576c2820 100644
--- a/llms/tests/test_utils.py
+++ b/llms/tests/test_utils.py
@@ -1,8 +1,11 @@
 # Copyright © 2024 Apple Inc.
 
+import os +import tempfile import unittest import mlx.core as mx +import mlx.nn as nn from mlx.utils import tree_flatten from mlx_lm import utils @@ -11,6 +14,17 @@ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit" class TestUtils(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.test_dir_fid = tempfile.TemporaryDirectory() + cls.test_dir = cls.test_dir_fid.name + if not os.path.isdir(cls.test_dir): + os.mkdir(cls.test_dir_fid.name) + + @classmethod + def tearDownClass(cls): + cls.test_dir_fid.cleanup() + def test_load(self): model, _ = utils.load(HF_MODEL_PATH) @@ -40,6 +54,40 @@ class TestUtils(unittest.TestCase): shards = utils.make_shards(dict(weights), 1) self.assertTrue(gb <= len(shards) <= gb + 1) + def test_quantize(self): + from mlx_lm.models import llama + + args = llama.ModelArgs( + model_type="llama", + hidden_size=1024, + num_hidden_layers=4, + intermediate_size=2048, + num_attention_heads=4, + rms_norm_eps=1e-5, + vocab_size=10_000, + ) + model = llama.Model(args) + weights, config = utils.quantize_model(model, {}, 64, 4) + self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights) + self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights) + self.assertEqual(config["quantization"]["group_size"], 64) + self.assertEqual(config["quantization"]["bits"], 4) + + def test_convert(self): + mlx_path = os.path.join(self.test_dir, "mlx_model") + + utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=True) + model, _ = utils.load(mlx_path) + self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear)) + self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear)) + + # Check model weights have right type + utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16") + model, _ = utils.load(mlx_path) + + self.assertEqual(model.layers[0].mlp.up_proj.weight.dtype, mx.bfloat16) + self.assertEqual(model.layers[-1].mlp.up_proj.weight.dtype, mx.bfloat16) + if __name__ == "__main__": unittest.main()
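
As a usage illustration (not part of the patch), the new lora_parameters dict plugs straight into tuner.utils.linear_to_lora_layers; the model path and layer count below are placeholders, and the parameter values mirror the new CONFIG_DEFAULTS entry.

# Usage sketch (illustrative only): apply LoRA with an explicit parameter config.
from mlx.utils import tree_flatten
from mlx_lm import tuner, utils

model, tokenizer = utils.load("path/to/model")  # placeholder local path or HF repo

lora_config = {
    # "keys" is optional; when omitted, per-architecture defaults are used
    # (e.g. self_attn.q_proj / self_attn.v_proj for llama-style models).
    "keys": ["self_attn.q_proj", "self_attn.v_proj"],
    "rank": 8,
    "alpha": 16.0,
    "scale": 10.0,
}

model.freeze()
# Convert the last 4 transformer blocks; the count must not exceed len(model.layers).
tuner.utils.linear_to_lora_layers(model, 4, lora_config)

n_trainable = sum(v.size for _, v in tree_flatten(model.trainable_parameters()))
print(f"Trainable parameters: {n_trainable / 1e6:.3f}M")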