mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
some simplifications and fixes
@@ -38,7 +38,7 @@ class DoRALinear(nn.Module):
         dtype = weight.dtype

         output_dims, input_dims = weight.shape
-        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+        fused_linear = nn.Linear(input_dims, output_dims, bias=False)

         lora_b = (self.scale * self.lora_b.T).astype(dtype)
         lora_a = self.lora_a.T.astype(dtype)
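The hunk above comes from the method that folds the DoRA adapter back into a plain nn.Linear. A minimal sketch of the arithmetic that fusion performs, assuming the standard DoRA formulation; the helper name fuse_dora_weight and the standalone arrays are illustrative and not part of the diff:

import mlx.core as mx

def fuse_dora_weight(w, lora_a, lora_b, m, scale):
    # Adapted weight: W + s * (B^T A^T), matching the lora_b / lora_a transposes above
    adapted = w + (scale * lora_b.T) @ lora_a.T
    # Rescale every output row so its norm matches the learned magnitude m
    norm_scale = m / mx.linalg.norm(adapted, axis=1)
    return norm_scale[:, None] * adapted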
@@ -86,10 +86,10 @@ class DoRALinear(nn.Module):

     def set_linear(self, linear):
         """
-        Set the self.linear layer and recompute self.m with respect to quantization
+        Set the self.linear layer and recompute self.m.
         """
         self.linear = linear
-        self.m = mx.linalg.norm(self._dequantized_weight(), axis=1).astype(mx.float32)
+        self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)

     def _dequantized_weight(self):
         """
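The second change in this hunk moves the float32 cast inside the norm: the row norms are now reduced in float32 rather than in the (typically half-precision) weight dtype, which is the more accurate order of operations. A small illustrative sketch of the difference, assuming float16 weights (the array shapes here are arbitrary):

import mlx.core as mx

w = mx.random.normal(shape=(4, 4096)).astype(mx.float16)

m_old = mx.linalg.norm(w, axis=1).astype(mx.float32)   # reduce in float16, cast afterwards
m_new = mx.linalg.norm(w.astype(mx.float32), axis=1)   # cast first, reduce in float32

print(mx.abs(m_old - m_new))  # the two can differ slightly for low-precision weights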
@@ -111,31 +111,18 @@ class DoRALinear(nn.Module):

     def __call__(self, x):
         # Regular LoRA (without a bias)
-        if self._is_quantized():
-            # Use quantized_matmul instead of dequantizing for efficiency
-            y = mx.quantized_matmul(
-                x,
-                self.linear.weight,
-                scales=self.linear.scales,
-                biases=self.linear.biases,
-                transpose=True,
-                group_size=self.linear.group_size,
-                bits=self.linear.bits,
-            )
-        else:
-            y = x @ self.linear.weight.T
+        w = self._dequantized_weight()
+        y = x @ w.T

         z = (self.dropout(x) @ self.lora_a) @ self.lora_b
         out = y + (self.scale * z).astype(x.dtype)

         # Compute the norm of the adapted weights
-        adapted = (
-            self._dequantized_weight() + (self.scale * self.lora_b.T) @ self.lora_a.T
-        )
+        adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
         denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))

         # Remove the norm and scale by the learned magnitude
-        out = (self.m / denom) * out
+        out = (self.m / denom).astype(x.dtype) * out

         if "bias" in self.linear:
             out = out + self.linear.bias
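The simplified __call__ above drops the quantized_matmul fast path: the base weight is dequantized once, reused for both the projection and the weight-norm denominator, and the learned magnitude m is applied per output feature in the input dtype. A standalone sketch of the same forward pass on plain arrays (the function name and argument list are illustrative; the real implementation lives in DoRALinear and also applies dropout to the low-rank branch):

import mlx.core as mx

def dora_forward(x, w, lora_a, lora_b, m, scale, bias=None):
    y = x @ w.T                                   # base projection with the dequantized W
    z = (x @ lora_a) @ lora_b                     # low-rank update (dropout omitted here)
    out = y + (scale * z).astype(x.dtype)
    adapted = w + (scale * lora_b.T) @ lora_a.T   # W + s * BA
    denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
    out = (m / denom).astype(x.dtype) * out       # rescale by the learned magnitude
    return out if bias is None else out + bias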
@@ -1,182 +0,0 @@
-import math
-import sys
-import unittest
-from io import StringIO
-
-import mlx.core as mx
-import mlx.nn as nn
-import mlx.optimizers as optim
-from mlx.utils import tree_flatten
-from mlx_lm import tuner
-from mlx_lm.tuner.dora import DoRALinear
-
-
-class TestDora(unittest.TestCase):
-    def setUp(self):
-        self.capturedOutput = StringIO()
-        sys.stdout = self.capturedOutput
-
-    def tearDown(self):
-        sys.stdout = sys.__stdout__
-
-    def test_llama(self):
-        from mlx_lm.models import llama
-
-        hidden_size = 1024
-        intermediate_size = 2048
-        args = llama.ModelArgs(
-            model_type="llama",
-            hidden_size=hidden_size,
-            num_hidden_layers=4,
-            intermediate_size=intermediate_size,
-            num_attention_heads=4,
-            rms_norm_eps=1e-5,
-            vocab_size=10_000,
-        )
-
-        dora_layers = 4
-
-        def check_config(params):
-            n_keys = 2
-            if "keys" in params:
-                n_keys = len(params["keys"])
-            model = llama.Model(args)
-            model.freeze()
-            tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
-            trainable_params = sum(
-                v.size for _, v in tree_flatten(model.trainable_parameters())
-            )
-            self.assertEqual(
-                trainable_params,
-                dora_layers
-                * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
-            )
-
-        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
-        check_config(params)
-
-        params["rank"] = 1
-        check_config(params)
-
-        params["keys"] = ["self_attn.k_proj"]
-        check_config(params)
-
-    def mx_assert_equal(self, a, b):
-        self.assertTrue(mx.array_equal(a, b))
-
-    def mx_assert_not_equal(self, a, b):
-        self.assertFalse(mx.array_equal(a, b))
-
-    def test_dora_m_parameter(self):
-        dora_lin = DoRALinear(input_dims=100, output_dims=100)
-        self.mx_assert_equal(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
-
-        # Recomputes m when changing Linear
-        inital_m = dora_lin.m
-        dora_lin.set_linear(nn.Linear(10, 10))
-        self.mx_assert_not_equal(inital_m, dora_lin.m)
-        self.mx_assert_equal(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
-
-        # Works with quantized weights
-        quantized_linear = nn.QuantizedLinear(512, 512)
-        dora_lin.set_linear(quantized_linear)
-        dequantized_weight = mx.dequantize(
-            quantized_linear.weight,
-            quantized_linear.scales,
-            quantized_linear.biases,
-            quantized_linear.group_size,
-            quantized_linear.bits,
-        )
-        self.mx_assert_equal(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
-
-    def test_dora_from_linear(self):
-        in_dims = 1024
-        out_dims = 512
-        r = 4
-
-        linear = nn.Linear(in_dims, out_dims)
-        dora_lin = DoRALinear.from_linear(linear, r)
-        self.mx_assert_equal(dora_lin.m, mx.linalg.norm(linear.weight, axis=1))
-        self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
-        self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
-        self.assertEqual(dora_lin.m.shape, (out_dims,))
-
-        quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
-        dequantized_weight = mx.dequantize(
-            quantized_linear.weight,
-            quantized_linear.scales,
-            quantized_linear.biases,
-            quantized_linear.group_size,
-            quantized_linear.bits,
-        )
-        dora_quant_lin = DoRALinear.from_linear(quantized_linear, r)
-        self.mx_assert_equal(
-            dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1)
-        )
-        self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
-        self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
-        self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
-
-    def test_dora_to_linear(self):
-        in_dims = 1024
-        out_dims = 512
-        r = 4
-
-        linear = nn.Linear(in_dims, out_dims, bias=True)
-        dora_lin = DoRALinear.from_linear(linear, r)
-        to_linear = dora_lin.to_linear()
-        self.mx_assert_equal(linear.weight, to_linear.weight)
-        self.mx_assert_equal(linear.bias, to_linear.bias)
-
-        def dequantize_weight(quantized_linear):
-            return mx.dequantize(
-                quantized_linear.weight,
-                quantized_linear.scales,
-                quantized_linear.biases,
-                quantized_linear.group_size,
-                quantized_linear.bits,
-            )
-
-        quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
-        dora_quantized_linear = DoRALinear.from_linear(quantized_linear, r)
-        # Dequantize
-        to_linear_from_quantized = dora_quantized_linear.to_linear(de_quantize=True)
-        self.mx_assert_equal(quantized_linear.bias, to_linear_from_quantized.bias)
-        self.mx_assert_equal(
-            dequantize_weight(quantized_linear), to_linear_from_quantized.weight
-        )
-
-    def test_dora_backprop(self):
-        in_dims = 1024
-        out_dims = 512
-        r = 4
-
-        linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
-        dora_lin = DoRALinear.from_linear(linear, r)
-        dora_lin.train()
-
-        input = mx.random.uniform(shape=(in_dims,))
-        target = mx.random.uniform(shape=(out_dims,))
-
-        optimizer = optim.Adam(learning_rate=2e-5)
-
-        def loss_fn(inputs, targets):
-            outputs = dora_lin(inputs)
-            loss = (outputs - targets).square().mean()
-            return loss
-
-        loss_value_and_grad = nn.value_and_grad(dora_lin, loss_fn)
-        initial_loss = None
-        for i in range(20):
-            loss, grad = loss_value_and_grad(input, target)
-            self.assertFalse(math.isnan(loss.item()))
-            optimizer.update(dora_lin, grad)
-
-            if i == 0:
-                initial_loss = loss
-
-        self.assertGreater(initial_loss, loss)
-
-
-if __name__ == "__main__":
-    unittest.main()
@@ -11,7 +11,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_flatten
 from mlx_lm import lora, tuner
-from mlx_lm.tuner.dora import DoRAEmbedding
+from mlx_lm.tuner.dora import DoRAEmbedding, DoRALinear
 from mlx_lm.tuner.lora import LoRAEmbedding, LoRALinear
 from mlx_lm.tuner.trainer import evaluate
 from mlx_lm.tuner.utils import build_schedule
@@ -164,6 +164,147 @@ class TestDora(unittest.TestCase):
         self.assertFalse(mx.array_equal(embedding.weight, new_embedding.weight))
         self.assertFalse(mx.array_equal(embedding(tokens), dora_emb(tokens)))

+    def test_llama(self):
+        from mlx_lm.models import llama
+
+        hidden_size = 1024
+        intermediate_size = 2048
+        args = llama.ModelArgs(
+            model_type="llama",
+            hidden_size=hidden_size,
+            num_hidden_layers=4,
+            intermediate_size=intermediate_size,
+            num_attention_heads=4,
+            rms_norm_eps=1e-5,
+            vocab_size=10_000,
+        )
+
+        dora_layers = 4
+
+        def check_config(params):
+            n_keys = 2
+            if "keys" in params:
+                n_keys = len(params["keys"])
+            model = llama.Model(args)
+            model.freeze()
+            tuner.utils.linear_to_lora_layers(model, dora_layers, params, use_dora=True)
+            trainable_params = sum(
+                v.size for _, v in tree_flatten(model.trainable_parameters())
+            )
+            self.assertEqual(
+                trainable_params,
+                dora_layers
+                * (params["rank"] * hidden_size * 2 * n_keys + n_keys * hidden_size),
+            )
+
+        params = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
+        check_config(params)
+
+        params["rank"] = 1
+        check_config(params)
+
+        params["keys"] = ["self_attn.k_proj"]
+        check_config(params)
+
+    def test_dora_m_parameter(self):
+        dora_lin = DoRALinear(input_dims=100, output_dims=100)
+        self.assertTrue(
+            mx.allclose(dora_lin.m, mx.linalg.norm(dora_lin.linear.weight, axis=1))
+        )
+
+        # Recomputes m when changing Linear
+        inital_m = dora_lin.m
+        lin = nn.Linear(10, 10)
+        dora_lin.set_linear(lin)
+        self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(lin.weight, axis=1)))
+
+        # Works with quantized weights
+        quantized_linear = nn.QuantizedLinear(512, 512)
+        dora_lin.set_linear(quantized_linear)
+        dequantized_weight = mx.dequantize(
+            quantized_linear.weight,
+            quantized_linear.scales,
+            quantized_linear.biases,
+            quantized_linear.group_size,
+            quantized_linear.bits,
+        )
+        self.assertTrue(
+            mx.allclose(dora_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
+        )
+
+    def test_dora_from_linear(self):
+        in_dims = 256
+        out_dims = 256
+        r = 4
+
+        linear = nn.Linear(in_dims, out_dims)
+        dora_lin = DoRALinear.from_base(linear, r)
+        self.assertTrue(mx.allclose(dora_lin.m, mx.linalg.norm(linear.weight, axis=1)))
+        self.assertEqual(dora_lin.lora_a.shape, (in_dims, r))
+        self.assertEqual(dora_lin.lora_b.shape, (r, out_dims))
+        self.assertEqual(dora_lin.m.shape, (out_dims,))
+
+        quantized_linear = nn.QuantizedLinear(in_dims, out_dims)
+        dequantized_weight = mx.dequantize(
+            quantized_linear.weight,
+            quantized_linear.scales,
+            quantized_linear.biases,
+            quantized_linear.group_size,
+            quantized_linear.bits,
+        )
+        dora_quant_lin = DoRALinear.from_base(quantized_linear, r)
+        self.assertTrue(
+            mx.allclose(dora_quant_lin.m, mx.linalg.norm(dequantized_weight, axis=1))
+        )
+        self.assertEqual(dora_quant_lin.lora_a.shape, (in_dims, r))
+        self.assertEqual(dora_quant_lin.lora_b.shape, (r, out_dims))
+        self.assertEqual(dora_quant_lin.m.shape, (out_dims,))
+
+    def test_dora_to_linear(self):
+        in_dims = 256
+        out_dims = 256
+        r = 4
+
+        linear = nn.Linear(in_dims, out_dims, bias=True)
+        dora_lin = DoRALinear.from_base(linear, r)
+        to_linear = dora_lin.fuse()
+        self.assertTrue(mx.allclose(linear.weight, to_linear.weight))
+        self.assertTrue(mx.allclose(linear.bias, to_linear.bias))
+
+        def dequantize_weight(quantized_linear):
+            return mx.dequantize(
+                quantized_linear.weight,
+                quantized_linear.scales,
+                quantized_linear.biases,
+                quantized_linear.group_size,
+                quantized_linear.bits,
+            )
+
+        quantized_linear = nn.QuantizedLinear(in_dims, out_dims, bias=True)
+        dora_quantized_linear = DoRALinear.from_base(quantized_linear, r)
+        # Dequantize
+        to_linear_from_quantized = dora_quantized_linear.fuse(de_quantize=True)
+        self.assertTrue(
+            mx.allclose(quantized_linear.bias, to_linear_from_quantized.bias)
+        )
+        self.assertTrue(
+            mx.allclose(
+                dequantize_weight(quantized_linear), to_linear_from_quantized.weight
+            )
+        )
+
+    def test_dora_dtype(self):
+        in_dims = 256
+        out_dims = 256
+        r = 4
+
+        linear = nn.Linear(in_dims, out_dims, bias=True)
+        linear.set_dtype(mx.float16)
+        dora_lin = DoRALinear.from_base(linear, r)
+
+        x = mx.random.uniform(shape=(2, 256)).astype(mx.float16)
+        self.assertEqual(dora_lin(x).dtype, mx.float16)
+

 class TestScheduleConfig(unittest.TestCase):
     def test_join(self):
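The parameter-count assertion in test_llama above follows from the DoRA layer shapes: for the square attention projections in this model, each adapted projection contributes two rank-r matrices (rank * hidden_size entries each) plus the per-row magnitude vector (hidden_size entries). A quick arithmetic check for the test's default configuration (values taken from the test itself; the variable names below are illustrative):

rank, hidden_size, n_keys, dora_layers = 8, 1024, 2, 4
per_layer = rank * hidden_size * 2 * n_keys + n_keys * hidden_size
#         = 8 * 1024 * 2 * 2             + 2 * 1024
#         = 32768                         + 2048 = 34816
print(dora_layers * per_layer)  # 139264 trainable parameters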