Feature: QDoRA (#891)

* feat: QDoRA with tests and a small bug fix for recalculation of self.m

* some simplifications and fixes

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Zai Thottakath 2024-09-30 10:01:11 -05:00 committed by GitHub
parent aa1c8abdc6
commit 418d9a5511
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 183 additions and 11 deletions

View File

@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
dropout: float = 0.0,
scale: float = 20.0,
):
# TODO support quantized weights in DoRALinear
# TODO remove when input_dims and output_dims are attributes
# on linear and quantized linear
output_dims, input_dims = linear.weight.shape
if isinstance(linear, nn.QuantizedLinear):
raise ValueError("DoRALinear does not yet support quantization.")
input_dims *= 32 // linear.bits
dora_lin = DoRALinear(
input_dims=input_dims,
output_dims=output_dims,
@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
def fuse(self, de_quantize: bool = False):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
weight = self._dequantized_weight()
# Use the same type as the linear weight if not quantized
# Use the same type as the linear weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
lora_b = (self.scale * self.lora_b.T).astype(dtype)
lora_a = self.lora_a.T.astype(dtype)
@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
if bias:
fused_linear.bias = linear.bias
if self._is_quantized() and not de_quantize:
fused_linear = nn.QuantizedLinear.from_linear(
fused_linear,
linear.group_size,
linear.bits,
)
return fused_linear
def __init__(
@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def set_linear(self, linear: nn.Linear):
def set_linear(self, linear):
"""
Set the self.linear layer and recompute self.m.
"""
self.linear = linear
self.m = mx.linalg.norm(self.linear.weight, axis=1)
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
def _dequantized_weight(self):
"""
Return the weight of linear layer and dequantize it if is quantized
"""
weight = self.linear.weight
if self._is_quantized():
weight = mx.dequantize(
weight,
self.linear.scales,
self.linear.biases,
self.linear.group_size,
self.linear.bits,
)
return weight
def _is_quantized(self):
return isinstance(self.linear, nn.QuantizedLinear)
def __call__(self, x):
# Regular LoRA (without a bias)
y = x @ self.linear.weight.T
w = self._dequantized_weight()
y = x @ w.T
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
out = y + (self.scale * z).astype(x.dtype)
# Compute the norm of the adapted weights
adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
# Remove the norm and scale by the learned magnitude
out = (self.m / denom) * out
out = (self.m / denom).astype(x.dtype) * out
if "bias" in self.linear:
out = out + self.linear.bias

View File

@ -11,7 +11,7 @@ import mlx.nn as nn
import mlx.optimizers as opt
from mlx.utils import tree_flatten
from mlx_lm import lora, tuner
from mlx_lm.tuner.dora import DoRAEmbedding
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
from mlx_lm.tuner.trainer import evaluate
from mlx_lm.tuner.utils import build_schedule
@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
def test_llama(self):
from mlx_lm.models import llama
hidden_size = 1024
intermediate_size = 2048
args = llama.ModelArgs(
model_type="llama",
hidden_size=hidden_size,
num_hidden_layers=4,
intermediate_size=intermediate_size,
num_attention_heads=4,
rms_norm_eps=1e-5,
vocab_size=10_000,
)
dora_layers = 4
def check_config(params):
n_keys = 2
if "keys" in params:
n_keys = len(params["keys"])
model = llama.Model(args)
model.freeze()
tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
trainable_params = sum(
v.size for _, v in tree_flatten(model.trainable_parameters())
)
self.assertEqual(
trainable_params,
dora_layers
* (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
)
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
check_config(params)
params["rank"] = 1
check_config(params)
params["keys"] = ["self_attn.k_proj"]
check_config(params)
def test_dora_m_parameter(self):
dora_lin = DoRALinear(input_dims=100, output_dims=100)
self.assertTrue(
mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
)
# Recomputes m when changing Linear
inital_m = dora_lin.m
lin = nn.Linear(10, 10)
dora_lin.set_linear(lin)
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
# Works with quantized weights
quantized_linear = nn.QuantizedLinear(512, 512)
dora_lin.set_linear(quantized_linear)
dequantized_weight = mx.dequantize(
quantized_linear.weight,
quantized_linear.scales,
quantized_linear.biases,
quantized_linear.group_size,
quantized_linear.bits,
)
self.assertTrue(
mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
)
def test_dora_from_linear(self):
in_dims = 256
out_dims = 256
r = 4
linear = nn.Linear(in_dims, out_dims)
dora_lin = DoRALinear.from_base(linear, r)
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
self.assertEqual(dora_lin.m.shape, (out_dims,))
quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
dequantized_weight = mx.dequantize(
quantized_linear.weight,
quantized_linear.scales,
quantized_linear.biases,
quantized_linear.group_size,
quantized_linear.bits,
)
dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
self.assertTrue(
mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
)
self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
def test_dora_to_linear(self):
in_dims = 256
out_dims = 256
r = 4
linear = nn.Linear(in_dims, out_dims, bias=True)
dora_lin = DoRALinear.from_base(linear, r)
to_linear = dora_lin.fuse()
self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
def dequantize_weight(quantized_linear):
return mx.dequantize(
quantized_linear.weight,
quantized_linear.scales,
quantized_linear.biases,
quantized_linear.group_size,
quantized_linear.bits,
)
quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
# Dequantize
to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
self.assertTrue(
mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
)
self.assertTrue(
mx.allclose(
dequantize_weight(quantized_linear), to_linear_from_quantized.weight
)
)
def test_dora_dtype(self):
in_dims = 256
out_dims = 256
r = 4
linear = nn.Linear(in_dims, out_dims, bias=True)
linear.set_dtype(mx.float16)
dora_lin = DoRALinear.from_base(linear, r)
x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
self.assertEqual(dora_lin(x).dtype, mx.float16)
class TestScheduleConfig(unittest.TestCase):
def test_join(self):