diff --git a/llms/mlx_lm/fuse.py b/llms/mlx_lm/fuse.py
index fa06eb54..16457036 100644
--- a/llms/mlx_lm/fuse.py
+++ b/llms/mlx_lm/fuse.py
@@ -6,8 +6,8 @@ from pathlib import Path
 from mlx.utils import tree_flatten, tree_unflatten
 
 from .gguf import convert_to_gguf
-from .tuner.dora import DoRALinear
-from .tuner.lora import LoRALinear, LoRASwitchLinear
+from .tuner.dora import DoRAEmbedding, DoRALinear
+from .tuner.lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
     fetch_from_hub,
@@ -80,9 +80,11 @@ def main() -> None:
     model = apply_lora_layers(model, args.adapter_path)
 
     fused_linears = [
-        (n, m.to_linear())
+        (n, m.fuse())
         for n, m in model.named_modules()
-        if isinstance(m, (LoRASwitchLinear, LoRALinear, DoRALinear))
+        if isinstance(
+            m, (LoRASwitchLinear, LoRALinear, LoRAEmbedding, DoRALinear, DoRAEmbedding)
+        )
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py
index de10556b..bd2dfb01 100644
--- a/llms/mlx_lm/tuner/dora.py
+++ b/llms/mlx_lm/tuner/dora.py
@@ -8,7 +8,7 @@ import mlx.nn as nn
 
 class DoRALinear(nn.Module):
     @staticmethod
-    def from_linear(
+    def from_base(
         linear: nn.Linear,
         r: int = 8,
         dropout: float = 0.0,
@@ -25,10 +25,10 @@ class DoRALinear(nn.Module):
             dropout=dropout,
             scale=scale,
         )
-        dora_lin.linear = linear
+        dora_lin.set_linear(linear)
         return dora_lin
 
-    def to_linear(self, de_quantize: bool = False):
+    def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -61,7 +61,7 @@ class DoRALinear(nn.Module):
         super().__init__()
 
         # Regular linear layer weights
-        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+        self.set_linear(nn.Linear(input_dims, output_dims, bias=bias))
         self.dropout = nn.Dropout(p=dropout)
 
         # Scale for low-rank update
@@ -75,6 +75,9 @@ class DoRALinear(nn.Module):
             shape=(input_dims, r),
         )
         self.lora_b = mx.zeros(shape=(r, output_dims))
+
+    def set_linear(self, linear: nn.Linear):
+        self.linear = linear
         self.m = mx.linalg.norm(self.linear.weight, axis=1)
 
     def __call__(self, x):
@@ -93,3 +96,102 @@ class DoRALinear(nn.Module):
         if "bias" in self.linear:
             out = out + self.linear.bias
         return out
+
+
+class DoRAEmbedding(nn.Module):
+    def from_base(
+        embedding: nn.Embedding,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        num_embeddings, dims = embedding.weight.shape
+
+        # TODO support quantized weights in DoRAEmbedding
+        if isinstance(embedding, nn.QuantizedEmbedding):
+            raise ValueError("DoRAEmbedding does not yet support quantization.")
+        dora_embedding = DoRAEmbedding(
+            num_embeddings=num_embeddings,
+            dims=dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        dora_embedding.set_embedding(embedding)
+        return dora_embedding
+
+    def fuse(self, de_quantize: bool = False):
+        embedding = self.embedding
+        weight = embedding.weight
+
+        # Use the same type as the embedding weight
+        dtype = weight.dtype
+
+        num_embeddings, dims = weight.shape
+        fused_embedding = nn.Embedding(num_embeddings, dims)
+
+        lora_a = (self.scale * self.lora_a).astype(dtype)
+        lora_b = self.lora_b.astype(dtype)
+        weight = weight + lora_a @ lora_b
+        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+        fused_embedding.weight = norm_scale[:, None] * weight
+
+        return fused_embedding
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        super().__init__()
+
+        # Regular embedding layer weights
+        self.set_embedding(nn.Embedding(num_embeddings, dims))
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(num_embeddings)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(num_embeddings, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, dims))
+
+    def set_embedding(self, embedding: nn.Module):
+        self.embedding = embedding
+        self.m = mx.linalg.norm(embedding.weight, axis=1)
+
+    def __call__(self, x):
+        y = self.embedding(x)
+        z = self.scale * self.lora_a[x] @ self.lora_b
+        out = y + self.dropout(z).astype(y.dtype)
+
+        # Compute the norm of the adapted weights for the individual embeddings
+        adapted = y + z
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=-1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m[x] / denom)[..., None] * out
+
+        return out
+
+    def as_linear(self, x):
+        y = self.embedding.as_linear(x)
+        z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
+        out = y + (self.scale * z).astype(x.dtype)
+
+        # Compute the norm of the adapted weights
+        adapted = self.embedding.weight + (self.scale * self.lora_a) @ self.lora_b
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m / denom) * out
+
+        return out
diff --git a/llms/mlx_lm/tuner/lora.py b/llms/mlx_lm/tuner/lora.py
index 19babb0e..c788cb73 100644
--- a/llms/mlx_lm/tuner/lora.py
+++ b/llms/mlx_lm/tuner/lora.py
@@ -10,7 +10,7 @@ from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
 
 class LoRALinear(nn.Module):
     @staticmethod
-    def from_linear(
+    def from_base(
         linear: nn.Linear,
         r: int = 8,
         dropout: float = 0.0,
@@ -31,7 +31,7 @@ class LoRALinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self, de_quantize: bool = False):
+    def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -41,7 +41,7 @@ class LoRALinear(nn.Module):
         dtype = weight.dtype
 
         if is_quantized:
-            dtype = mx.float16
+            dtype = linear.scales.dtype
             weight = mx.dequantize(
                 weight,
                 linear.scales,
@@ -103,7 +103,7 @@ class LoRALinear(nn.Module):
 
 class LoRASwitchLinear(nn.Module):
     @staticmethod
-    def from_linear(
+    def from_base(
         linear: nn.Module,
         r: int = 8,
         dropout: float = 0.0,
@@ -120,7 +120,7 @@ class LoRASwitchLinear(nn.Module):
         lora_lin.linear = linear
         return lora_lin
 
-    def to_linear(self, de_quantize: bool = False):
+    def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
         weight = linear.weight
@@ -191,3 +191,95 @@ class LoRASwitchLinear(nn.Module):
         z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1)
 
         return y + (self.scale * z).astype(x.dtype)
+
+
+class LoRAEmbedding(nn.Module):
+    @staticmethod
+    def from_base(
+        embedding: nn.Embedding,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        num_embeddings, dims = embedding.weight.shape
+        if isinstance(embedding, nn.QuantizedEmbedding):
+            dims *= 32 // embedding.bits
+        lora_embedding = LoRAEmbedding(
+            num_embeddings=num_embeddings,
+            dims=dims,
+            r=r,
+            dropout=dropout,
+            scale=scale,
+        )
+        lora_embedding.embedding = embedding
+        return lora_embedding
+
+    def fuse(self, de_quantize: bool = False):
+        embedding = self.embedding
+        weight = embedding.weight
+        is_quantized = isinstance(embedding, nn.QuantizedEmbedding)
+
+        # Use the same type as the embedding weight if not quantized
+        dtype = weight.dtype
+
+        if is_quantized:
+            dtype = embedding.scales.dtype
+            weight = mx.dequantize(
+                weight,
+                embedding.scales,
+                embedding.biases,
+                embedding.group_size,
+                embedding.bits,
+            )
+        num_embeddings, dims = weight.shape
+        fused_embedding = nn.Embedding(num_embeddings, dims)
+
+        lora_a = (self.scale * self.lora_a).astype(dtype)
+        lora_b = self.lora_b.astype(dtype)
+        fused_embedding.weight = weight + lora_a @ lora_b
+
+        if is_quantized and not de_quantize:
+            fused_embedding = nn.QuantizedEmbedding.from_embedding(
+                fused_embedding,
+                embedding.group_size,
+                embedding.bits,
+            )
+
+        return fused_embedding
+
+    def __init__(
+        self,
+        num_embeddings: int,
+        dims: int,
+        r: int = 8,
+        dropout: float = 0.0,
+        scale: float = 20.0,
+    ):
+        super().__init__()
+
+        # Regular embedding layer
+        self.embedding = nn.Embedding(num_embeddings, dims)
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(num_embeddings)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(num_embeddings, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, dims))
+
+    def __call__(self, x):
+        y = self.embedding(x)
+        z = self.dropout(self.lora_a[x] @ self.lora_b)
+        out = y + (self.scale * z).astype(y.dtype)
+        return out
+
+    def as_linear(self, x):
+        y = self.embedding.as_linear(x)
+        z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T
+        return y + (self.scale * z).astype(x.dtype)
diff --git a/llms/mlx_lm/tuner/utils.py b/llms/mlx_lm/tuner/utils.py
index 2c97228d..b7f1f9de 100644
--- a/llms/mlx_lm/tuner/utils.py
+++ b/llms/mlx_lm/tuner/utils.py
@@ -10,8 +10,8 @@ import mlx.optimizers as opt
 from mlx.utils import tree_flatten, tree_unflatten
 
 from ..models.switch_layers import QuantizedSwitchLinear, SwitchLinear
-from .dora import DoRALinear
-from .lora import LoRALinear, LoRASwitchLinear
+from .dora import DoRAEmbedding, DoRALinear
+from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear
 
 
 def build_schedule(schedule_config: Dict):
@@ -71,12 +71,14 @@ def linear_to_lora_layers(
             if use_dora:
                 raise ValueError(f"{type(layer).__name__} doesn't support DoRA yet.")
             LoRALayer = LoRASwitchLinear
+        elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)):
+            LoRALayer = DoRAEmbedding if use_dora else LoRAEmbedding
         else:
            raise ValueError(
                f"Can't convert layer of type {type(layer).__name__} to LoRA"
            )
 
-        return LoRALayer.from_linear(
+        return LoRALayer.from_base(
             layer,
             r=config["rank"],
             scale=config["scale"],
@@ -130,7 +132,12 @@ def linear_to_lora_layers(
 
     for l in model.layers[num_layers - num_lora_layers :]:
         lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys]
-        l.update_modules(tree_unflatten(lora_layers))
+        if lora_layers:
+            l.update_modules(tree_unflatten(lora_layers))
+
+    lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys]
+    if lora_modules:
+        model.update_modules(tree_unflatten(lora_modules))
 
 
 def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
diff --git a/llms/tests/test_lora.py b/llms/tests/test_finetune.py
similarity index 65%
rename from llms/tests/test_lora.py
rename to llms/tests/test_finetune.py
index f37ae3c2..289b8cfb 100644
--- a/llms/tests/test_lora.py
+++ b/llms/tests/test_finetune.py
@@ -6,10 +6,13 @@ import unittest
 from io import StringIO
 from unittest.mock import MagicMock
 
+import mlx.core as mx
+import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
-from mlx_lm.tuner.lora import LoRALinear
+from mlx_lm.tuner.dora import DoRAEmbedding
+from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
 from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
 
@@ -33,11 +36,12 @@ class TestLora(unittest.TestCase):
             num_attention_heads=4,
             rms_norm_eps=1e-5,
             vocab_size=10_000,
+            tie_word_embeddings=False,
         )
 
         lora_layers = 4
 
-        def check_config(params):
+        def check_config(params, expected_trainable_parameters=None):
             n_keys = 2
             if "keys" in params:
                 n_keys = len(params["keys"])
@@ -47,9 +51,11 @@ class TestLora(unittest.TestCase):
             trainable_params = sum(
                 v.size for _, v in tree_flatten(model.trainable_parameters())
             )
-            self.assertEqual(
-                trainable_params, lora_layers * params["rank"] * 1024 * 2 * n_keys
+
+            expected_trainable_parameters = expected_trainable_parameters or (
+                lora_layers * params["rank"] * args.hidden_size * 2 * n_keys
             )
+            self.assertEqual(trainable_params, expected_trainable_parameters)
 
         params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
         check_config(params)
@@ -60,6 +66,22 @@ class TestLora(unittest.TestCase):
         params["keys"] = ["self_attn.k_proj"]
         check_config(params)
 
+        params["keys"] = ["lm_head"]
+        check_config(
+            params,
+            expected_trainable_parameters=(
+                params["rank"] * (args.hidden_size + args.vocab_size)
+            ),
+        )
+
+        params["keys"] = ["model.embed_tokens"]
+        check_config(
+            params,
+            expected_trainable_parameters=(
+                params["rank"] * (args.hidden_size + args.vocab_size)
+            ),
+        )
+
     def test_gpt_neox(self):
         from mlx_lm.models import gpt_neox
 
@@ -82,6 +104,66 @@ class TestLora(unittest.TestCase):
         model.freeze()
         tuner.utils.linear_to_lora_layers(model, num_lora_layers, params)
 
+    def test_lora_embedding(self):
+        num_embeddings = 256
+        dims = 512
+        tokens = mx.array([1, 2, 3])
+
+        embedding = nn.QuantizedEmbedding(num_embeddings, dims)
+        dequantized_weight = mx.dequantize(
+            embedding.weight,
+            embedding.scales,
+            embedding.biases,
+            embedding.group_size,
+            embedding.bits,
+        )
+        lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
+        new_embedding = lora_emb.fuse(de_quantize=True)
+        self.assertTrue(mx.array_equal(dequantized_weight, new_embedding.weight))
+        self.assertTrue(mx.array_equal(embedding(tokens), lora_emb(tokens)))
+
+        # as_linear
+        attn_output = mx.random.uniform(shape=(dims,))
+        embedding_lin_out = lora_emb.as_linear(attn_output)
+        self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
+        self.assertTrue(
+            mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
+        )
+
+        # change the value of lora_b and the embeddings will no longer be equal
+        lora_emb.lora_b = mx.random.uniform(shape=lora_emb.lora_b.shape)
+        new_embedding = lora_emb.fuse(de_quantize=True)
+        self.assertFalse(mx.array_equal(dequantized_weight, new_embedding.weight))
+        self.assertFalse(mx.array_equal(embedding(tokens), lora_emb(tokens)))
+
+
+class TestDora(unittest.TestCase):
+    def test_dora_embedding(self):
+        num_embeddings = 256
+        dims = 512
+        tokens = mx.array([1, 2, 3])
+
+        embedding = nn.Embedding(num_embeddings, dims)
+
+        dora_emb = DoRAEmbedding.from_base(embedding, r=8, dropout=0, scale=10)
+        new_embedding = dora_emb.fuse()
+        self.assertTrue(mx.array_equal(embedding.weight, new_embedding.weight))
+        self.assertTrue(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+
+        # as_linear
+        attn_output = mx.random.uniform(shape=(dims,))
+        embedding_lin_out = dora_emb.as_linear(attn_output)
+        self.assertEqual(embedding_lin_out.shape, (num_embeddings,))
+        self.assertTrue(
+            mx.array_equal(embedding_lin_out, embedding.as_linear(attn_output))
+        )
+
+        # change the value of lora_b and the embeddings will no longer be equal
+        dora_emb.lora_b = mx.random.uniform(shape=dora_emb.lora_b.shape)
+        new_embedding = dora_emb.fuse()
+        self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
+        self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
+
 
 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):
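
Usage sketch (not part of the patch): the snippet below mirrors what the added tests exercise and shows how the new embedding adapter introduced above is meant to be driven. The 256/512 sizes and variable names are illustrative; LoRAEmbedding.from_base, __call__, as_linear, and fuse are the APIs added in this diff.

    import mlx.core as mx
    import mlx.nn as nn

    from mlx_lm.tuner.lora import LoRAEmbedding

    # Wrap a plain embedding with a low-rank (LoRA) update.
    embedding = nn.Embedding(256, 512)
    lora_emb = LoRAEmbedding.from_base(embedding, r=8, dropout=0.0, scale=10.0)

    tokens = mx.array([1, 2, 3])
    hidden = mx.random.uniform(shape=(512,))

    y = lora_emb(tokens)                 # token lookup with the low-rank update applied
    logits = lora_emb.as_linear(hidden)  # tied-head projection, shape (256,)

    # Fold lora_a @ lora_b back into a plain nn.Embedding for export.
    fused = lora_emb.fuse()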