mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Feature: QDoRA (#891)
* feat: QDoRA with tests and a small bug fix for recalculation of self.m * some simplifications and fixes --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
aa1c8abdc6
commit
418d9a5511
@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
|
|||||||
dropout: float = 0.0,
|
dropout: float = 0.0,
|
||||||
scale: float = 20.0,
|
scale: float = 20.0,
|
||||||
):
|
):
|
||||||
# TODO support quantized weights in DoRALinear
|
# TODO remove when input_dims and output_dims are attributes
|
||||||
|
# on linear and quantized linear
|
||||||
output_dims, input_dims = linear.weight.shape
|
output_dims, input_dims = linear.weight.shape
|
||||||
if isinstance(linear, nn.QuantizedLinear):
|
if isinstance(linear, nn.QuantizedLinear):
|
||||||
raise ValueError("DoRALinear does not yet support quantization.")
|
input_dims *= 32 // linear.bits
|
||||||
dora_lin = DoRALinear(
|
dora_lin = DoRALinear(
|
||||||
input_dims=input_dims,
|
input_dims=input_dims,
|
||||||
output_dims=output_dims,
|
output_dims=output_dims,
|
||||||
@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
|
|||||||
def fuse(self, de_quantize: bool = False):
|
def fuse(self, de_quantize: bool = False):
|
||||||
linear = self.linear
|
linear = self.linear
|
||||||
bias = "bias" in linear
|
bias = "bias" in linear
|
||||||
weight = linear.weight
|
weight = self._dequantized_weight()
|
||||||
|
|
||||||
# Use the same type as the linear weight if not quantized
|
# Use the same type as the linear weight
|
||||||
dtype = weight.dtype
|
dtype = weight.dtype
|
||||||
|
|
||||||
output_dims, input_dims = weight.shape
|
output_dims, input_dims = weight.shape
|
||||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
|
||||||
|
|
||||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||||
lora_a = self.lora_a.T.astype(dtype)
|
lora_a = self.lora_a.T.astype(dtype)
|
||||||
@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
|
|||||||
|
|
||||||
if bias:
|
if bias:
|
||||||
fused_linear.bias = linear.bias
|
fused_linear.bias = linear.bias
|
||||||
|
|
||||||
|
if self._is_quantized() and not de_quantize:
|
||||||
|
fused_linear = nn.QuantizedLinear.from_linear(
|
||||||
|
fused_linear,
|
||||||
|
linear.group_size,
|
||||||
|
linear.bits,
|
||||||
|
)
|
||||||
return fused_linear
|
return fused_linear
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
|
|||||||
)
|
)
|
||||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||||
|
|
||||||
def set_linear(self, linear: nn.Linear):
|
def set_linear(self, linear):
|
||||||
|
"""
|
||||||
|
Set the self.linear layer and recompute self.m.
|
||||||
|
"""
|
||||||
self.linear = linear
|
self.linear = linear
|
||||||
self.m = mx.linalg.norm(self.linear.weight, axis=1)
|
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
|
||||||
|
|
||||||
|
def _dequantized_weight(self):
|
||||||
|
"""
|
||||||
|
Return the weight of linear layer and dequantize it if is quantized
|
||||||
|
"""
|
||||||
|
weight = self.linear.weight
|
||||||
|
if self._is_quantized():
|
||||||
|
weight = mx.dequantize(
|
||||||
|
weight,
|
||||||
|
self.linear.scales,
|
||||||
|
self.linear.biases,
|
||||||
|
self.linear.group_size,
|
||||||
|
self.linear.bits,
|
||||||
|
)
|
||||||
|
return weight
|
||||||
|
|
||||||
|
def _is_quantized(self):
|
||||||
|
return isinstance(self.linear, nn.QuantizedLinear)
|
||||||
|
|
||||||
def __call__(self, x):
|
def __call__(self, x):
|
||||||
# Regular LoRA (without a bias)
|
# Regular LoRA (without a bias)
|
||||||
y = x @ self.linear.weight.T
|
w = self._dequantized_weight()
|
||||||
|
y = x @ w.T
|
||||||
|
|
||||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||||
out = y + (self.scale * z).astype(x.dtype)
|
out = y + (self.scale * z).astype(x.dtype)
|
||||||
|
|
||||||
# Compute the norm of the adapted weights
|
# Compute the norm of the adapted weights
|
||||||
adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
|
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
|
||||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
||||||
|
|
||||||
# Remove the norm and scale by the learned magnitude
|
# Remove the norm and scale by the learned magnitude
|
||||||
out = (self.m / denom) * out
|
out = (self.m / denom).astype(x.dtype) * out
|
||||||
|
|
||||||
if "bias" in self.linear:
|
if "bias" in self.linear:
|
||||||
out = out + self.linear.bias
|
out = out + self.linear.bias
|
||||||
|
@ -11,7 +11,7 @@ import mlx.nn as nn
|
|||||||
import mlx.optimizers as opt
|
import mlx.optimizers as opt
|
||||||
from mlx.utils import tree_flatten
|
from mlx.utils import tree_flatten
|
||||||
from mlx_lm import lora, tuner
|
from mlx_lm import lora, tuner
|
||||||
from mlx_lm.tuner.dora import DoRAEmbedding
|
from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
|
||||||
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
|
from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
|
||||||
from mlx_lm.tuner.trainer import evaluate
|
from mlx_lm.tuner.trainer import evaluate
|
||||||
from mlx_lm.tuner.utils import build_schedule
|
from mlx_lm.tuner.utils import build_schedule
|
||||||
@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
|
|||||||
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
|
self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
|
||||||
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
|
self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))
|
||||||
|
|
||||||
|
def test_llama(self):
|
||||||
|
from mlx_lm.models import llama
|
||||||
|
|
||||||
|
hidden_size = 1024
|
||||||
|
intermediate_size = 2048
|
||||||
|
args = llama.ModelArgs(
|
||||||
|
model_type="llama",
|
||||||
|
hidden_size=hidden_size,
|
||||||
|
num_hidden_layers=4,
|
||||||
|
intermediate_size=intermediate_size,
|
||||||
|
num_attention_heads=4,
|
||||||
|
rms_norm_eps=1e-5,
|
||||||
|
vocab_size=10_000,
|
||||||
|
)
|
||||||
|
|
||||||
|
dora_layers = 4
|
||||||
|
|
||||||
|
def check_config(params):
|
||||||
|
n_keys = 2
|
||||||
|
if "keys" in params:
|
||||||
|
n_keys = len(params["keys"])
|
||||||
|
model = llama.Model(args)
|
||||||
|
model.freeze()
|
||||||
|
tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
|
||||||
|
trainable_params = sum(
|
||||||
|
v.size for _, v in tree_flatten(model.trainable_parameters())
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
trainable_params,
|
||||||
|
dora_layers
|
||||||
|
* (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
|
||||||
|
check_config(params)
|
||||||
|
|
||||||
|
params["rank"] = 1
|
||||||
|
check_config(params)
|
||||||
|
|
||||||
|
params["keys"] = ["self_attn.k_proj"]
|
||||||
|
check_config(params)
|
||||||
|
|
||||||
|
def test_dora_m_parameter(self):
|
||||||
|
dora_lin = DoRALinear(input_dims=100, output_dims=100)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
|
||||||
|
)
|
||||||
|
|
||||||
|
# Recomputes m when changing Linear
|
||||||
|
inital_m = dora_lin.m
|
||||||
|
lin = nn.Linear(10, 10)
|
||||||
|
dora_lin.set_linear(lin)
|
||||||
|
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
|
||||||
|
|
||||||
|
# Works with quantized weights
|
||||||
|
quantized_linear = nn.QuantizedLinear(512, 512)
|
||||||
|
dora_lin.set_linear(quantized_linear)
|
||||||
|
dequantized_weight = mx.dequantize(
|
||||||
|
quantized_linear.weight,
|
||||||
|
quantized_linear.scales,
|
||||||
|
quantized_linear.biases,
|
||||||
|
quantized_linear.group_size,
|
||||||
|
quantized_linear.bits,
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dora_from_linear(self):
|
||||||
|
in_dims = 256
|
||||||
|
out_dims = 256
|
||||||
|
r = 4
|
||||||
|
|
||||||
|
linear = nn.Linear(in_dims, out_dims)
|
||||||
|
dora_lin = DoRALinear.from_base(linear, r)
|
||||||
|
self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
|
||||||
|
self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
|
||||||
|
self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
|
||||||
|
self.assertEqual(dora_lin.m.shape, (out_dims,))
|
||||||
|
|
||||||
|
quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
|
||||||
|
dequantized_weight = mx.dequantize(
|
||||||
|
quantized_linear.weight,
|
||||||
|
quantized_linear.scales,
|
||||||
|
quantized_linear.biases,
|
||||||
|
quantized_linear.group_size,
|
||||||
|
quantized_linear.bits,
|
||||||
|
)
|
||||||
|
dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
|
||||||
|
)
|
||||||
|
self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
|
||||||
|
self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
|
||||||
|
self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
|
||||||
|
|
||||||
|
def test_dora_to_linear(self):
|
||||||
|
in_dims = 256
|
||||||
|
out_dims = 256
|
||||||
|
r = 4
|
||||||
|
|
||||||
|
linear = nn.Linear(in_dims, out_dims, bias=True)
|
||||||
|
dora_lin = DoRALinear.from_base(linear, r)
|
||||||
|
to_linear = dora_lin.fuse()
|
||||||
|
self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
|
||||||
|
self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
|
||||||
|
|
||||||
|
def dequantize_weight(quantized_linear):
|
||||||
|
return mx.dequantize(
|
||||||
|
quantized_linear.weight,
|
||||||
|
quantized_linear.scales,
|
||||||
|
quantized_linear.biases,
|
||||||
|
quantized_linear.group_size,
|
||||||
|
quantized_linear.bits,
|
||||||
|
)
|
||||||
|
|
||||||
|
quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
|
||||||
|
dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
|
||||||
|
# Dequantize
|
||||||
|
to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
|
||||||
|
)
|
||||||
|
self.assertTrue(
|
||||||
|
mx.allclose(
|
||||||
|
dequantize_weight(quantized_linear), to_linear_from_quantized.weight
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_dora_dtype(self):
|
||||||
|
in_dims = 256
|
||||||
|
out_dims = 256
|
||||||
|
r = 4
|
||||||
|
|
||||||
|
linear = nn.Linear(in_dims, out_dims, bias=True)
|
||||||
|
linear.set_dtype(mx.float16)
|
||||||
|
dora_lin = DoRALinear.from_base(linear, r)
|
||||||
|
|
||||||
|
x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
|
||||||
|
self.assertEqual(dora_lin(x).dtype, mx.float16)
|
||||||
|
|
||||||
|
|
||||||
class TestScheduleConfig(unittest.TestCase):
|
class TestScheduleConfig(unittest.TestCase):
|
||||||
def test_join(self):
|
def test_join(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user