LoRA on all linear transformer block layers (#546)

* Add --lora-all-linear option to apply LoRa to all linear transfer block layers

* Moved to YAML config and added specification of rank & alpha

* nits in conifg, more tests

* nit

* run tests for prs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Chime Ogbuji
2024-03-12 10:37:40 -04:00
committed by GitHub
parent fe5edee360
commit e56d9015ef
8 changed files with 163 additions and 40 deletions

52
llms/tests/test_lora.py Normal file
View File

@@ -0,0 +1,52 @@
# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import tuner, utils
class TestLora(unittest.TestCase):
def test_to_lora(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
lora_layers = 4
def check_config(params):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
model = llama.Model(args)
model.freeze()
tuner.utils.linear_to_lora_layers(model, lora_layers, params)
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
self.assertEqual(
trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
params["rank"] = 1
check_config(params)
params["keys"] = ["self_attn.k_proj"]
check_config(params)
if __name__ == "__main__":
unittest.main()

View File

@@ -1,8 +1,11 @@
# Copyright © 2024 Apple Inc.
import os
import tempfile
import unittest
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_flatten
from mlx_lm import utils
@@ -11,6 +14,17 @@ HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestUtils(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.test_dir_fid = tempfile.TemporaryDirectory()
cls.test_dir = cls.test_dir_fid.name
if not os.path.isdir(cls.test_dir):
os.mkdir(cls.test_dir_fid.name)
@classmethod
def tearDownClass(cls):
cls.test_dir_fid.cleanup()
def test_load(self):
model, _ = utils.load(HF_MODEL_PATH)
@@ -40,6 +54,40 @@ class TestUtils(unittest.TestCase):
shards = utils.make_shards(dict(weights), 1)
self.assertTrue(gb <= len(shards) <= gb + 1)
def test_quantize(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
model = llama.Model(args)
weights, config = utils.quantize_model(model, {}, 64, 4)
self.assertTrue("model.layers.2.mlp.up_proj.scales" in weights)
self.assertTrue("model.layers.2.mlp.up_proj.biases" in weights)
self.assertEqual(config["quantization"]["group_size"], 64)
self.assertEqual(config["quantization"]["bits"], 4)
def test_convert(self):
mlx_path = os.path.join(self.test_dir, "mlx_model")
utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, quantize=True)
model, _ = utils.load(mlx_path)
self.assertTrue(isinstance(model.layers[0].mlp.up_proj, nn.QuantizedLinear))
self.assertTrue(isinstance(model.layers[-1].mlp.up_proj, nn.QuantizedLinear))
# Check model weights have right type
utils.convert(HF_MODEL_PATH, mlx_path=mlx_path, dtype="bfloat16")
model, _ = utils.load(mlx_path)
self.assertEqual(model.layers[0].mlp.up_proj.weight.dtype, mx.bfloat16)
self.assertEqual(model.layers[-1].mlp.up_proj.weight.dtype, mx.bfloat16)
if __name__ == "__main__":
unittest.main()