Allow the entire model to be targed for LoRA and DoRA fine tuning: LoRA and DoRA embeddings with small DoRALinear bug fix (#914)

* feature: LoRA adapter for Embeddings

* feature: wire in LoRAEmbedding into the tuner. Allow the embedding and non model.layers Linear layers to be targeted for fine tuning

* feature: DoRA adapter for Embeddings

* feature: wire in DoRAEmbedding

* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear

* refactor: prefer from_base over from_linear or from_embedding. prefer fuse over to_linear or to_embedding

* cleanup: remove unused imports in test_dora.py

* refactor: remove unnecessary non_layer_modules

* cleanup: remove wrong comments for lora embedding dropout. remove uncessary parens in dora embedding dropout

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Zai Thottakath
2024-08-16 09:38:36 -05:00
committed by GitHub
parent c50971e860
commit 4e01700816
5 changed files with 306 additions and 21 deletions

View File

@@ -6,10 +6,13 @@ import unittest
from io import StringIO
from unittest.mock import MagicMock
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
from mlx_lm.tuner.lora import LoRALinear
from mlx_lm.tuner.dora import DoRAEmbedding
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@@ -33,11 +36,12 @@ class TestLora(unittest.TestCase):
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
tie_word_embeddings=False,
)
lora_layers = 4
def check_config(params):
def check_config(params, expected_trainable_parameters=None):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
@@ -47,9 +51,11 @@ class TestLora(unittest.TestCase):
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
self.assertEqual(
trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
expected_trainable_parameters = expected_trainable_parameters or (
lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
)
self.assertEqual(trainable_params, expected_trainable_parameters)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
@@ -60,6 +66,22 @@ class TestLora(unittest.TestCase):
params["keys"] = ["self_attn.k_proj"]
check_config(params)
params["keys"] = ["lm_head"]
check_config(
params,
expected_trainable_parameters=(
params["rank"] * (args.hidden_size + args.vocab_size)
),
)
params["keys"] = ["model.embed_tokens"]
check_config(
params,
expected_trainable_parameters=(
params["rank"] * (args.hidden_size + args.vocab_size)
),
)
def test_gpt_neox(self):
from mlx_lm.models import gpt_neox
@@ -82,6 +104,66 @@ class TestLora(unittest.TestCase):
model.freeze()
tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
def test_lora_embedding(self):
num_embeddings = 256
dims = 512
tokens = mx.array([1, 2, 3])
embedding = nn.QuantizedEmbedding(num_embeddings, dims)
dequantized_weight = mx.dequantize(
embedding.weight,
embedding.scales,
embedding.biases,
embedding.group_size,
embedding.bits,
)
lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
new_embedding = lora_emb.fuse(de_quantize=True)
self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))
# as_linear
attn_output = mx.random.uniform(shape=(dims,))
embedding_lin_out = lora_emb.as_linear(attn_output)
self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
self.assertTrue(
mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
)
# change the value of lora_b and the embeddings will no longer be equal
lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
new_embedding = lora_emb.fuse(de_quantize=True)
self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))
class TestDora(unittest.TestCase):
def test_dora_embedding(self):
num_embeddings = 256
dims = 512
tokens = mx.array([1, 2, 3])
embedding = nn.Embedding(num_embeddings, dims)
dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
new_embedding = dora_emb.fuse()
self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))
# as_linear
attn_output = mx.random.uniform(shape=(dims,))
embedding_lin_out = dora_emb.as_linear(attn_output)
self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
self.assertTrue(
mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
)
# change the value of lora_b and the embeddings will no longer be equal
dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
new_embedding = dora_emb.fuse()
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
class TestScheduleConfig(unittest.TestCase):
def test_join(self):