diff --git a/llms/mlx_lm/tuner/dora.py b/llms/mlx_lm/tuner/dora.py
index bd2dfb01..aba1f6f4 100644
--- a/llms/mlx_lm/tuner/dora.py
+++ b/llms/mlx_lm/tuner/dora.py
@@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
         dropout: float = 0.0,
         scale: float = 20.0,
     ):
-        # TODO support quantized weights in DoRALinear
+        # TODO remove when input_dims and output_dims are attributes
+        # on linear and quantized linear
         output_dims, input_dims = linear.weight.shape
         if isinstance(linear, nn.QuantizedLinear):
-            raise ValueError("DoRALinear does not yet support quantization.")
+            input_dims *= 32 // linear.bits
         dora_lin = DoRALinear(
             input_dims=input_dims,
             output_dims=output_dims,
@@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
     def fuse(self, de_quantize: bool = False):
         linear = self.linear
         bias = "bias" in linear
-        weight = linear.weight
+        weight = self._dequantized_weight()
 
-        # Use the same type as the linear weight if not quantized
+        # Use the same type as the linear weight
         dtype = weight.dtype
 
         output_dims, input_dims = weight.shape
-        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+        fused_linear = nn.Linear(input_dims, output_dims, bias=False)
 
         lora_b = (self.scale * self.lora_b.T).astype(dtype)
         lora_a = self.lora_a.T.astype(dtype)
@@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
 
         if bias:
             fused_linear.bias = linear.bias
+
+        if self._is_quantized() and not de_quantize:
+            fused_linear = nn.QuantizedLinear.from_linear(
+                fused_linear,
+                linear.group_size,
+                linear.bits,
+            )
         return fused_linear
 
     def __init__(
@@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
         )
         self.lora_b = mx.zeros(shape=(r, output_dims))
 
-    def set_linear(self, linear: nn.Linear):
+    def set_linear(self, linear):
+        """
+        Set the self.linear layer and recompute self.m.
+        """
         self.linear = linear
-        self.m = mx.linalg.norm(self.linear.weight, axis=1)
+        self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
+
+    def _dequantized_weight(self):
+        """
+        Return the weight of linear layer and dequantize it if is quantized
+        """
+        weight = self.linear.weight
+        if self._is_quantized():
+            weight = mx.dequantize(
+                weight,
+                self.linear.scales,
+                self.linear.biases,
+                self.linear.group_size,
+                self.linear.bits,
+            )
+        return weight
+
+    def _is_quantized(self):
+        return isinstance(self.linear, nn.QuantizedLinear)
 
     def __call__(self, x):
         # Regular LoRA (without a bias)
-        y = x @ self.linear.weight.T
+        w = self._dequantized_weight()
+        y = x @ w.T
+
         z = (self.dropout(x) @ self.lora_a) @ self.lora_b
         out = y + (self.scale * z).astype(x.dtype)
 
         # Compute the norm of the adapted weights
-        adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
+        adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
         denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
 
         # Remove the norm and scale by the learned magnitude
-        out = (self.m / denom) * out
+        out = (self.m / denom).astype(x.dtype) * out
 
         if "bias" in self.linear:
             out = out + self.linear.bias
diff --git a/llms/tests/test_finetune.py b/llms/tests/test_finetune.py
index 289b8cfb..107be092 100644
--- a/llms/tests/test_finetune.py
+++ b/llms/tests/test_finetune.py
@@ -11,7 +11,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
-from mlx_lm.tuner.dora import DoRAEmbedding
+from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
 from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
 from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
@@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
         self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
         self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
 
+    def test_llama(self):
+        from mlx_lm.models import llama
+
+        hidden_size = 1024
+        intermediate_size = 2048
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=hidden_size,
+            num_hidden_layers=4,
+            intermediate_size=intermediate_size,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+
+        dora_layers = 4
+
+        def check_config(params):
+            n_keys = 2
+            if "keys" in params:
+                n_keys = len(params["keys"])
+            model = llama.Model(args)
+            model.freeze()
+            tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
+            trainable_params = sum(
+                v.size for _, v in tree_flatten(model.trainable_parameters())
+            )
+            self.assertEqual(
+                trainable_params,
+                dora_layers
+                * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
+            )
+
+        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
+        check_config(params)
+
+        params["rank"] = 1
+        check_config(params)
+
+        params["keys"] = ["self_attn.k_proj"]
+        check_config(params)
+
+    def test_dora_m_parameter(self):
+        dora_lin = DoRALinear(input_dims=100, output_dims=100)
+        self.assertTrue(
+            mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
+        )
+
+        # Recomputes m when changing Linear
+        inital_m = dora_lin.m
+        lin = nn.Linear(10, 10)
+        dora_lin.set_linear(lin)
+        self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
+
+        # Works with quantized weights
+        quantized_linear = nn.QuantizedLinear(512, 512)
+        dora_lin.set_linear(quantized_linear)
+        dequantized_weight = mx.dequantize(
+            quantized_linear.weight,
+            quantized_linear.scales,
+            quantized_linear.biases,
+            quantized_linear.group_size,
+            quantized_linear.bits,
+        )
+        self.assertTrue(
+            mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
+        )
+
+    def test_dora_from_linear(self):
+        in_dims = 256
+        out_dims = 256
+        r = 4
+
+        linear = nn.Linear(in_dims, out_dims)
+        dora_lin = DoRALinear.from_base(linear, r)
+        self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
+        self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
+        self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
+        self.assertEqual(dora_lin.m.shape, (out_dims,))
+
+        quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
+        dequantized_weight = mx.dequantize(
+            quantized_linear.weight,
+            quantized_linear.scales,
+            quantized_linear.biases,
+            quantized_linear.group_size,
+            quantized_linear.bits,
+        )
+        dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
+        self.assertTrue(
+            mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
+        )
+        self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
+        self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
+        self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
+
+    def test_dora_to_linear(self):
+        in_dims = 256
+        out_dims = 256
+        r = 4
+
+        linear = nn.Linear(in_dims, out_dims, bias=True)
+        dora_lin = DoRALinear.from_base(linear, r)
+        to_linear = dora_lin.fuse()
+        self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
+        self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
+
+        def dequantize_weight(quantized_linear):
+            return mx.dequantize(
+                quantized_linear.weight,
+                quantized_linear.scales,
+                quantized_linear.biases,
+                quantized_linear.group_size,
+                quantized_linear.bits,
+            )
+
+        quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
+        dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
+        # Dequantize
+        to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
+        self.assertTrue(
+            mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
+        )
+        self.assertTrue(
+            mx.allclose(
+                dequantize_weight(quantized_linear), to_linear_from_quantized.weight
+            )
+        )
+
+    def test_dora_dtype(self):
+        in_dims = 256
+        out_dims = 256
+        r = 4
+
+        linear = nn.Linear(in_dims, out_dims, bias=True)
+        linear.set_dtype(mx.float16)
+        dora_lin = DoRALinear.from_base(linear, r)
+
+        x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
+        self.assertEqual(dora_lin(x).dtype, mx.float16)
+
 
 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):
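For reference, a minimal usage sketch (not part of the patch) of the quantized DoRA path this change enables. It assumes mlx and mlx_lm are installed as laid out in this repo and only uses the APIs exercised in the diff above (DoRALinear.from_base, __call__, and fuse); the variable names are illustrative.

    # Hypothetical example, not from the patch: wrap a quantized layer in DoRA.
    import mlx.core as mx
    import mlx.nn as nn
    from mlx_lm.tuner.dora import DoRALinear

    base = nn.QuantizedLinear(256, 256, bias=True)
    dora = DoRALinear.from_base(base, 4)    # m is computed from the dequantized weight

    x = mx.random.uniform(shape=(2, 256))
    y = dora(x)                             # forward pass dequantizes the base weight internally

    fused_q = dora.fuse()                   # stays quantized (nn.QuantizedLinear)
    fused_fp = dora.fuse(de_quantize=True)  # plain nn.Linear with dequantized weights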