mlx-examples/llms/tests/test_lora.py
Chime Ogbuji e56d9015ef
LoRA on all linear transformer block layers (#546)
* Add --lora-all-linear option to apply LoRA to all linear transformer block layers

* Moved to YAML config and added specification of rank & alpha

* nits in config, more tests

* nit

* run tests for prs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-12 07:37:40 -07:00

53 lines
1.3 KiB
Python

# Copyright © 2024 Apple Inc.
import unittest
import mlx.core as mx
from mlx.utils import tree_flatten
from mlx_lm import tuner, utils
class TestLora(unittest.TestCase):
def test_to_lora(self):
from mlx_lm.models import llama
args = llama.ModelArgs(
model_type="llama",
hidden_size=1024,
num_hidden_layers=4,
intermediate_size=2048,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
lora_layers = 4
def check_config(params):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
model = llama.Model(args)
model.freeze()
tuner.utils.linear_to_lora_layers(model, lora_layers, params)
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
self.assertEqual(
trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
params["rank"] = 1
check_config(params)
params["keys"] = ["self_attn.k_proj"]
check_config(params)
if __name__ == "__main__":
    # Only reached when this file is executed directly as a script (not when
    # imported by a test runner); delegate to unittest's command-line runner.
    unittest.main()