Allow the entire model to be targeted for LoRA and DoRA fine-tuning: LoRA and DoRA embeddings with a small DoRALinear bug fix (#914)

* feature: LoRA adapter for Embeddings

* feature: wire LoRAEmbedding into the tuner; allow the embedding and non-model.layers Linear layers to be targeted for fine-tuning

* feature: DoRA adapter for Embeddings

* feature: wire in DoRAEmbedding

* bugfix: ensure self.m is recalculated when the linear layer is changed in DoRALinear.from_linear

* refactor: prefer from_base over from_linear/from_embedding, and fuse over to_linear/to_embedding (a usage sketch of the new names follows the changed-file summary below)

* cleanup: remove unused imports in test_dora.py

* refactor: remove unnecessary non_layer_modules

* cleanup: remove incorrect comments on the LoRA embedding dropout; remove unnecessary parentheses in the DoRA embedding dropout

* nits

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Author: Zai Thottakath, 2024-08-16 09:38:36 -05:00 (committed by GitHub)
parent c50971e860
commit 4e01700816
5 changed files with 306 additions and 21 deletions
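
The new embedding adapters follow the same pattern as the existing Linear wrappers: wrap the base layer with from_base, train the low-rank parameters, then fuse back into a plain layer. A minimal sketch of that round trip, assuming mlx and mlx_lm are installed (the sizes below are illustrative, not taken from this diff):

    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.tuner.lora import LoRAEmbedding

    # Wrap a base embedding with a trainable low-rank update.
    embedding = nn.Embedding(1000, 64)
    lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0.0, scale=20.0)

    # The forward pass adds the LoRA update to the base lookup.
    tokens = mx.array([1, 2, 3])
    out = lora_emb(tokens)   # shape (3, 64)

    # After training, fold the update back into a plain nn.Embedding.
    fused = lora_emb.fuse()  # fused.weight has shape (1000, 64)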

@@ -6,8 +6,8 @@ from pathlib import Path
 from mlx.utils import tree_flatten, tree_unflatten
 
 from .gguf import convert_to_gguf
-from .tuner.dora import DoRALinear
-from .tuner.lora import LoRALinear, LoRASwitchLinear
+from .tuner.dora import DoRAEmbedding, DoRALinear
+from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
     fetch_from_hub,
@@ -80,9 +80,11 @@ def main() -> None:
     model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
-        (n, m.to_linear())
+        (n, m.fuse())
         for n, m in model.named_modules()
-        if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
+        if isinstance(
+            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
+        )
     ]
 
     model.update_modules(tree_unflatten(fused_linears))

@@ -8,7 +8,7 @@ import mlx.nn as nn
 
 class DoRALinear(nn.Module):
     @staticmethod
-    def from_linear(
+    def from_base(
         linear: nn.Linear,
         r: int = 8,
         dropout: float = 0.0,
@@ -25,10 +25,10 @@ class DoRALinear(nn.Module):
             dropout=dropout,
             scale=scale,
         )
-        dora_lin.linear = linear
+        dora_lin.set_linear(linear)
         return dora_lin
 
-    def to_linear(self, de_quantize: bool = False):
+    def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -61,7 +61,7 @@ class DoRALinear(nn.Module):
         super().__init__()
 
         # Regular linear layer weights
-        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+        self.set_linear(nn.Linear(input_dims, output_dims, bias=bias))
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
@@ -75,6 +75,9 @@ class DoRALinear(nn.Module):
             shape=(input_dims, r),
         )
         self.lora_b = mx.zeros(shape=(r, output_dims))
+
+    def set_linear(self, linear: nn.Linear):
+        self.linear = linear
         self.m = mx.linalg.norm(self.linear.weight, axis=1)
 
     def __call__(self, x):
@@ -93,3 +96,102 @@ class DoRALinear(nn.Module):
         if "bias" in self.linear:
             out = out + self.linear.bias
         return out
+
+
+class DoRAEmbedding(nn.Module):
+    def from_base(
+        embedding: nn.Embedding,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        num_embeddings, dims = embedding.weight.shape
+
+        # TODO support quantized weights in DoRALinear
+        if isinstance(embedding, nn.QuantizedLinear):
+            raise ValueError("DoRAEmbedding does not yet support quantization.")
+
+        dora_embedding = DoRAEmbedding(
+            num_embeddings=num_embeddings,
+            dims=dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        dora_embedding.set_embedding(embedding)
+        return dora_embedding
+
+    def fuse(self, de_quantize: bool = False):
+        embedding = self.embedding
+        weight = embedding.weight
+
+        # Use the same type as the linear weight if not quantized
+        dtype = weight.dtype
+
+        num_embeddings, dims = weight.shape
+        fused_embedding = nn.Embedding(num_embeddings, dims)
+
+        lora_a = (self.scale * self.lora_a).astype(dtype)
+        lora_b = self.lora_b.astype(dtype)
+        weight = weight + lora_a @ lora_b
+        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+        fused_embedding.weight = norm_scale[:, None] * weight
+
+        return fused_embedding
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        super().__init__()
+
+        # Regular embedding layer weights
+        self.set_embedding(nn.Embedding(num_embeddings, dims))
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(num_embeddings)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(num_embeddings, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, dims))
+
+    def set_embedding(self, embedding: nn.Module):
+        self.embedding = embedding
+        self.m = mx.linalg.norm(embedding.weight, axis=1)
+
+    def __call__(self, x):
+        y = self.embedding(x)
+        z = self.scale * self.lora_a[x] @ self.lora_b
+        out = y + self.dropout(z).astype(y.dtype)
+
+        # Compute the norm of the adapted weights for the individual embeddings
+        adapted = y + z
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m[x] / denom)[..., None] * out
+
+        return out
+
+    def as_linear(self, x):
+        y = self.embedding.as_linear(x)
+        z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
+        out = y + (self.scale * z).astype(x.dtype)
+
+        # Compute the norm of the adapted weights
+        adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m / denom) * out
+
+        return out
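
As a reading aid, the DoRAEmbedding forward and fuse paths above implement the standard DoRA decomposition: with A = lora_a, B = lora_b, scale s, and the learned magnitude m initialized to the row norms of the base weight,

    W' = \frac{m}{\lVert W + s\,A B \rVert_{\text{row}}} \odot \left( W + s\,A B \right)

The row norm in the denominator is wrapped in stop_gradient, so it is treated as a constant during backpropagation. Because B starts at zero, m equals the row norms of W at initialization and the adapter is an exact no-op, which is what the new embedding tests assert.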

@@ -10,7 +10,7 @@ from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 
 class LoRALinear(nn.Module):
     @staticmethod
-    def from_linear(
+    def from_base(
         linear: nn.Linear,
         r: int = 8,
         dropout: float = 0.0,
@@ -31,7 +31,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self, de_quantize: bool = False):
+    def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -41,7 +41,7 @@ class LoRALinear(nn.Module):
         # Use the same type as the linear weight if not quantized
         dtype = weight.dtype
         if is_quantized:
-            dtype = mx.float16
+            dtype = linear.scales.dtype
             weight = mx.dequantize(
                 weight,
                 linear.scales,
@@ -103,7 +103,7 @@ class LoRALinear(nn.Module):
 
 class LoRASwitchLinear(nn.Module):
     @staticmethod
-    def from_linear(
+    def from_base(
         linear: nn.Module,
         r: int = 8,
         dropout: float = 0.0,
@@ -120,7 +120,7 @@ class LoRASwitchLinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self, de_quantize: bool = False):
+    def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -191,3 +191,95 @@
         z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
 
         return y + (self.scale * z).astype(x.dtype)
+
+
+class LoRAEmbedding(nn.Module):
+    @staticmethod
+    def from_base(
+        embedding: nn.Embedding,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        num_embeddings, dims = embedding.weight.shape
+        if isinstance(embedding, nn.QuantizedEmbedding):
+            dims *= 32 // embedding.bits
+
+        lora_embedding = LoRAEmbedding(
+            num_embeddings=num_embeddings,
+            dims=dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        lora_embedding.embedding = embedding
+        return lora_embedding
+
+    def fuse(self, de_quantize: bool = False):
+        embedding = self.embedding
+        weight = embedding.weight
+        is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
+
+        # Use the same type as the linear weight if not quantized
+        dtype = weight.dtype
+        if is_quantized:
+            dtype = embedding.scales.dtype
+            weight = mx.dequantize(
+                weight,
+                embedding.scales,
+                embedding.biases,
+                embedding.group_size,
+                embedding.bits,
+            )
+        num_embeddings, dims = weight.shape
+        fused_embedding = nn.Embedding(num_embeddings, dims)
+
+        lora_a = (self.scale * self.lora_a).astype(dtype)
+        lora_b = self.lora_b.astype(dtype)
+        fused_embedding.weight = weight + lora_a @ lora_b
+
+        if is_quantized and not de_quantize:
+            fused_embedding = nn.QuantizedEmbedding.from_embedding(
+                fused_embedding,
+                embedding.group_size,
+                embedding.bits,
+            )
+
+        return fused_embedding
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        super().__init__()
+
+        # Regular embedding layer
+        self.embedding = nn.Embedding(num_embeddings, dims)
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(num_embeddings)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(num_embeddings, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, dims))
+
+    def __call__(self, x):
+        y = self.embedding(x)
+        z = self.dropout(self.lora_a[x] @ self.lora_b)
+        out = y + (self.scale * z).astype(y.dtype)
+        return out
+
+    def as_linear(self, x):
+        y = self.embedding.as_linear(x)
+        z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
+        return y + (self.scale * z).astype(x.dtype)
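
One note on as_linear: it covers the case where the embedding matrix is reused as the output projection (tied word embeddings), so the low-rank update is applied on that path as well. A small sketch of that path, assuming mlx and mlx_lm are installed (sizes are illustrative):

    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.tuner.lora import LoRAEmbedding

    emb = LoRAEmbedding.from_base(nn.Embedding(1000, 64), r=8)
    hidden = mx.random.uniform(shape=(64,))  # stand-in for a final hidden state
    logits = emb.as_linear(hidden)           # same projection a tied output head would use
    print(logits.shape)                      # (1000,)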

@@ -10,8 +10,8 @@ import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten
 
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
-from .dora import DoRALinear
-from .lora import LoRALinear, LoRASwitchLinear
+from .dora import DoRAEmbedding, DoRALinear
+from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
 
 
 def build_schedule(schedule_config: Dict):
@@ -71,12 +71,14 @@ def linear_to_lora_layers(
             if use_dora:
                 raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
             LoRALayer = LoRASwitchLinear
+        elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
+            LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
         else:
             raise ValueError(
                 f"Can't convert layer of type {type(layer).__name__} to LoRA"
            )
 
-        return LoRALayer.from_linear(
+        return LoRALayer.from_base(
             layer,
             r=config["rank"],
             scale=config["scale"],
@@ -130,8 +132,13 @@ def linear_to_lora_layers(
 
     for l in model.layers[num_layers - num_lora_layers :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
-        l.update_modules(tree_unflatten(lora_layers))
+        if lora_layers:
+            l.update_modules(tree_unflatten(lora_layers))
+
+    lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
+    if lora_modules:
+        model.update_modules(tree_unflatten(lora_modules))
 
 
 def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
     """

@@ -6,10 +6,13 @@ import unittest
 from io import StringIO
 from unittest.mock import MagicMock
 
+import mlx.core as mx
+import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 
 from mlx_lm import lora, tuner
-from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.dora import DoRAEmbedding
+from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
 from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
@@ -33,11 +36,12 @@ class TestLora(unittest.TestCase):
             num_attention_heads=4,
             rms_norm_eps=1e-5,
             vocab_size=10_000,
+            tie_word_embeddings=False,
         )
         lora_layers = 4
 
-        def check_config(params):
+        def check_config(params, expected_trainable_parameters=None):
             n_keys = 2
             if "keys" in params:
                 n_keys = len(params["keys"])
@@ -47,9 +51,11 @@ class TestLora(unittest.TestCase):
             trainable_params = sum(
                 v.size for _, v in tree_flatten(model.trainable_parameters())
             )
-            self.assertEqual(
-                trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
+            expected_trainable_parameters = expected_trainable_parameters or (
+                lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
             )
+            self.assertEqual(trainable_params, expected_trainable_parameters)
 
         params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
         check_config(params)
@@ -60,6 +66,22 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)
 
+        params["keys"] = ["lm_head"]
+        check_config(
+            params,
+            expected_trainable_parameters=(
+                params["rank"] * (args.hidden_size + args.vocab_size)
+            ),
+        )
+
+        params["keys"] = ["model.embed_tokens"]
+        check_config(
+            params,
+            expected_trainable_parameters=(
+                params["rank"] * (args.hidden_size + args.vocab_size)
+            ),
+        )
+
     def test_gpt_neox(self):
         from mlx_lm.models import gpt_neox
@@ -82,6 +104,66 @@
         model.freeze()
         tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
 
+    def test_lora_embedding(self):
+        num_embeddings = 256
+        dims = 512
+        tokens = mx.array([1, 2, 3])
+
+        embedding = nn.QuantizedEmbedding(num_embeddings, dims)
+        dequantized_weight = mx.dequantize(
+            embedding.weight,
+            embedding.scales,
+            embedding.biases,
+            embedding.group_size,
+            embedding.bits,
+        )
+        lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
+        new_embedding = lora_emb.fuse(de_quantize=True)
+        self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
+        self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))
+
+        # as_linear
+        attn_output = mx.random.uniform(shape=(dims,))
+        embedding_lin_out = lora_emb.as_linear(attn_output)
+        self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
+        self.assertTrue(
+            mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
+        )
+
+        # change the value of lora_b and the embeddings will no longer be equal
+        lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
+        new_embedding = lora_emb.fuse(de_quantize=True)
+        self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
+        self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))
+
+
+class TestDora(unittest.TestCase):
+    def test_dora_embedding(self):
+        num_embeddings = 256
+        dims = 512
+        tokens = mx.array([1, 2, 3])
+
+        embedding = nn.Embedding(num_embeddings, dims)
+        dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
+        new_embedding = dora_emb.fuse()
+        self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
+        self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+
+        # as_linear
+        attn_output = mx.random.uniform(shape=(dims,))
+        embedding_lin_out = dora_emb.as_linear(attn_output)
+        self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
+        self.assertTrue(
+            mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
+        )
+
+        # change the value of lora_b and the embeddings will no longer be equal
+        dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
+        new_embedding = dora_emb.fuse()
+        self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
+        self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+
+
 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):