mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00

* feature: LoRA adapter for Embeddings * feature: wire in LoRAEmbedding into the tuner. Allow the embedding and non model.layers Linear layers to be targeted for fine tuning * feature: DoRA adapter for Embeddings * feature: wire in DoRAEmbedding * bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear * refactor: prefer from_base over from_linear or from_embedding. prefer fuse over to_linear or to_embedding * cleanup: remove unused imports in test_dora.py * refactor: remove unnecessary non_layer_modules * cleanup: remove wrong comments for lora embedding dropout. remove unnecessary parens in dora embedding dropout * nits --------- Co-authored-by: Awni Hannun <awni@apple.com>
296 lines
9.8 KiB
Python
296 lines
9.8 KiB
Python
# Copyright © 2024 Apple Inc.
|
|
|
|
import math
|
|
import sys
|
|
import unittest
|
|
from io import StringIO
|
|
from unittest.mock import MagicMock
|
|
|
|
import mlx.core as mx
|
|
import mlx.nn as nn
|
|
import mlx.optimizers as opt
|
|
from mlx.utils import tree_flatten
|
|
from mlx_lm import lora, tuner
|
|
from mlx_lm.tuner.dora import DoRAEmbedding
|
|
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
|
|
from mlx_lm.tuner.trainer import evaluate
|
|
from mlx_lm.tuner.utils import build_schedule
|
|
|
|
|
|
class TestLora(unittest.TestCase):
|
|
def setUp(self):
|
|
self.capturedOutput = StringIO()
|
|
sys.stdout = self.capturedOutput
|
|
|
|
def tearDown(self):
|
|
sys.stdout = sys.__stdout__
|
|
|
|
def test_llama(self):
|
|
from mlx_lm.models import llama
|
|
|
|
args = llama.ModelArgs(
|
|
model_type="llama",
|
|
hidden_size=1024,
|
|
num_hidden_layers=4,
|
|
intermediate_size=2048,
|
|
num_attention_heads=4,
|
|
rms_norm_eps=1e-5,
|
|
vocab_size=10_000,
|
|
tie_word_embeddings=False,
|
|
)
|
|
|
|
lora_layers = 4
|
|
|
|
def check_config(params, expected_trainable_parameters=None):
|
|
n_keys = 2
|
|
if "keys" in params:
|
|
n_keys = len(params["keys"])
|
|
model = llama.Model(args)
|
|
model.freeze()
|
|
tuner.utils.linear_to_lora_layers(model, lora_layers, params)
|
|
trainable_params = sum(
|
|
v.size for _, v in tree_flatten(model.trainable_parameters())
|
|
)
|
|
|
|
expected_trainable_parameters = expected_trainable_parameters or (
|
|
lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
|
|
)
|
|
self.assertEqual(trainable_params, expected_trainable_parameters)
|
|
|
|
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
|
|
check_config(params)
|
|
|
|
params["rank"] = 1
|
|
check_config(params)
|
|
|
|
params["keys"] = ["self_attn.k_proj"]
|
|
check_config(params)
|
|
|
|
params["keys"] = ["lm_head"]
|
|
check_config(
|
|
params,
|
|
expected_trainable_parameters=(
|
|
params["rank"] * (args.hidden_size + args.vocab_size)
|
|
),
|
|
)
|
|
|
|
params["keys"] = ["model.embed_tokens"]
|
|
check_config(
|
|
params,
|
|
expected_trainable_parameters=(
|
|
params["rank"] * (args.hidden_size + args.vocab_size)
|
|
),
|
|
)
|
|
|
|
def test_gpt_neox(self):
|
|
from mlx_lm.models import gpt_neox
|
|
|
|
args = gpt_neox.ModelArgs(
|
|
model_type="gpt_neox",
|
|
max_position_embeddings=2048,
|
|
hidden_size=6144,
|
|
num_attention_heads=64,
|
|
num_hidden_layers=44,
|
|
layer_norm_eps=1e-5,
|
|
vocab_size=50432,
|
|
rotary_emb_base=10_000,
|
|
rotary_pct=0.25,
|
|
)
|
|
|
|
num_lora_layers = 4
|
|
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
|
|
|
|
model = gpt_neox.Model(args)
|
|
model.freeze()
|
|
tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
|
|
|
|
def test_lora_embedding(self):
|
|
num_embeddings = 256
|
|
dims = 512
|
|
tokens = mx.array([1, 2, 3])
|
|
|
|
embedding = nn.QuantizedEmbedding(num_embeddings, dims)
|
|
dequantized_weight = mx.dequantize(
|
|
embedding.weight,
|
|
embedding.scales,
|
|
embedding.biases,
|
|
embedding.group_size,
|
|
embedding.bits,
|
|
)
|
|
lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
|
|
new_embedding = lora_emb.fuse(de_quantize=True)
|
|
self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
|
|
self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))
|
|
|
|
# as_linear
|
|
attn_output = mx.random.uniform(shape=(dims,))
|
|
embedding_lin_out = lora_emb.as_linear(attn_output)
|
|
self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
|
|
self.assertTrue(
|
|
mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
|
|
)
|
|
|
|
# change the value of lora_b and the embeddings will no longer be equal
|
|
lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
|
|
new_embedding = lora_emb.fuse(de_quantize=True)
|
|
self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
|
|
self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))
|
|
|
|
|
|
class TestDora(unittest.TestCase):
|
|
def test_dora_embedding(self):
|
|
num_embeddings = 256
|
|
dims = 512
|
|
tokens = mx.array([1, 2, 3])
|
|
|
|
embedding = nn.Embedding(num_embeddings, dims)
|
|
|
|
dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
|
|
new_embedding = dora_emb.fuse()
|
|
self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
|
|
self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))
|
|
|
|
# as_linear
|
|
attn_output = mx.random.uniform(shape=(dims,))
|
|
embedding_lin_out = dora_emb.as_linear(attn_output)
|
|
self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
|
|
self.assertTrue(
|
|
mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
|
|
)
|
|
|
|
# change the value of lora_b and the embeddings will no longer be equal
|
|
dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
|
|
new_embedding = dora_emb.fuse()
|
|
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
|
|
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
|
|
|
|
|
|
class TestScheduleConfig(unittest.TestCase):
|
|
def test_join(self):
|
|
config = {"name": "cosine_decay", "warmup": 100, "arguments": [1e-5, 100]}
|
|
cos_with_warmup = build_schedule(config)
|
|
self.assertIsNotNone(cos_with_warmup)
|
|
|
|
self.assertEqual(cos_with_warmup(0), 0.0)
|
|
self.assertAlmostEqual(cos_with_warmup(101), 1e-5, delta=1e-1)
|
|
optimizer = opt.Adam(learning_rate=cos_with_warmup)
|
|
for _ in range(100):
|
|
optimizer.update({}, {})
|
|
self.assertAlmostEqual(optimizer.learning_rate.item(), 1e-5, delta=1e-1)
|
|
for _ in range(100):
|
|
optimizer.update({}, {})
|
|
expected_lr = 1e-5 * 0.5 * (1.0 + math.cos(math.pi * 200 / 10))
|
|
self.assertAlmostEqual(optimizer.learning_rate.item(), expected_lr, delta=1e-1)
|
|
|
|
def test_single_schedule(self):
|
|
|
|
config = {
|
|
"name": "cosine_decay",
|
|
"arguments": [0.1, 10],
|
|
}
|
|
lr_schedule = build_schedule(config)
|
|
lr = lr_schedule(4)
|
|
expected_lr = 0.1 * 0.5 * (1.0 + math.cos(math.pi * 4 / 10))
|
|
self.assertAlmostEqual(lr, expected_lr, delta=1e-7)
|
|
|
|
def test_non_zero_warmup(self):
|
|
config = {
|
|
"name": "cosine_decay",
|
|
"warmup": 10,
|
|
"warmup_init": 1e-6,
|
|
"arguments": [1e-5, 20],
|
|
}
|
|
lr_schedule = build_schedule(config)
|
|
lr = lr_schedule(0)
|
|
self.assertAlmostEqual(lr, 1e-6, delta=1e-7)
|
|
|
|
def test_malformed_config(self):
|
|
config = {"warmup": 100}
|
|
self.assertRaises(KeyError, build_schedule, config)
|
|
|
|
config = {"cosine_decay": None}
|
|
self.assertRaises(KeyError, build_schedule, config)
|
|
|
|
def test_evaluate_calls(self):
|
|
mock_model = MagicMock()
|
|
mock_dataset = MagicMock()
|
|
mock_tokenizer = MagicMock()
|
|
mock_default_loss = MagicMock()
|
|
mock_iterate_batches = MagicMock()
|
|
|
|
mock_iterate_batches.return_value = [
|
|
(MagicMock(), MagicMock()),
|
|
(MagicMock(), MagicMock()),
|
|
(MagicMock(), MagicMock()),
|
|
(MagicMock(), MagicMock()),
|
|
(MagicMock(), MagicMock()),
|
|
]
|
|
|
|
mock_default_loss.side_effect = [
|
|
(MagicMock(return_value=0.5), MagicMock(return_value=100)),
|
|
(MagicMock(return_value=0.3), MagicMock(return_value=200)),
|
|
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
|
(MagicMock(return_value=0.4), MagicMock(return_value=180)),
|
|
(MagicMock(return_value=0.6), MagicMock(return_value=120)),
|
|
]
|
|
evaluate(
|
|
model=mock_model,
|
|
dataset=mock_dataset,
|
|
tokenizer=mock_tokenizer,
|
|
batch_size=2,
|
|
num_batches=2,
|
|
max_seq_length=2048,
|
|
loss=mock_default_loss,
|
|
iterate_batches=mock_iterate_batches,
|
|
)
|
|
|
|
mock_iterate_batches.assert_called_once_with(
|
|
dataset=mock_dataset,
|
|
tokenizer=mock_tokenizer,
|
|
batch_size=2,
|
|
max_seq_length=2048,
|
|
)
|
|
self.assertEqual(mock_default_loss.call_count, 2)
|
|
|
|
def test_evaluate_infinite_batches(self):
|
|
mock_model = MagicMock()
|
|
mock_dataset = MagicMock()
|
|
mock_tokenizer = MagicMock()
|
|
mock_default_loss = MagicMock()
|
|
mock_iterate_batches = MagicMock()
|
|
|
|
mock_iterate_batches.return_value = [
|
|
(MagicMock(), MagicMock()),
|
|
(MagicMock(), MagicMock()),
|
|
(MagicMock(), MagicMock()),
|
|
]
|
|
|
|
mock_default_loss.side_effect = [
|
|
(MagicMock(return_value=0.5), MagicMock(return_value=100)),
|
|
(MagicMock(return_value=0.3), MagicMock(return_value=200)),
|
|
(MagicMock(return_value=0.2), MagicMock(return_value=150)),
|
|
]
|
|
|
|
evaluate(
|
|
model=mock_model,
|
|
dataset=mock_dataset,
|
|
tokenizer=mock_tokenizer,
|
|
batch_size=2,
|
|
num_batches=-1,
|
|
max_seq_length=2048,
|
|
loss=mock_default_loss,
|
|
iterate_batches=mock_iterate_batches,
|
|
)
|
|
|
|
mock_iterate_batches.assert_called_once_with(
|
|
dataset=mock_dataset,
|
|
tokenizer=mock_tokenizer,
|
|
batch_size=2,
|
|
max_seq_length=2048,
|
|
)
|
|
self.assertEqual(mock_default_loss.call_count, 3)
|
|
|
|
|
|
if __name__ == "__main__":
    # Executed directly (not imported by a runner): run all tests here.
    unittest.main()
|