mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Feature: QDoRA (#891)
* feat: QDoRA with tests and a small bug fix for recalculation of self.m * some simplifications and fixes --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -14,10 +14,11 @@ class DoRALinear(nn.Module):
|
||||
dropout: float = 0.0,
|
||||
scale: float = 20.0,
|
||||
):
|
||||
# TODO support quantized weights in DoRALinear
|
||||
# TODO remove when input_dims and output_dims are attributes
|
||||
# on linear and quantized linear
|
||||
output_dims, input_dims = linear.weight.shape
|
||||
if isinstance(linear, nn.QuantizedLinear):
|
||||
raise ValueError("DoRALinear does not yet support quantization.")
|
||||
input_dims *= 32 // linear.bits
|
||||
dora_lin = DoRALinear(
|
||||
input_dims=input_dims,
|
||||
output_dims=output_dims,
|
||||
@@ -31,13 +32,13 @@ class DoRALinear(nn.Module):
|
||||
def fuse(self, de_quantize: bool = False):
|
||||
linear = self.linear
|
||||
bias = "bias" in linear
|
||||
weight = linear.weight
|
||||
weight = self._dequantized_weight()
|
||||
|
||||
# Use the same type as the linear weight if not quantized
|
||||
# Use the same type as the linear weight
|
||||
dtype = weight.dtype
|
||||
|
||||
output_dims, input_dims = weight.shape
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
|
||||
fused_linear = nn.Linear(input_dims, output_dims, bias=False)
|
||||
|
||||
lora_b = (self.scale * self.lora_b.T).astype(dtype)
|
||||
lora_a = self.lora_a.T.astype(dtype)
|
||||
@@ -47,6 +48,13 @@ class DoRALinear(nn.Module):
|
||||
|
||||
if bias:
|
||||
fused_linear.bias = linear.bias
|
||||
|
||||
if self._is_quantized() and not de_quantize:
|
||||
fused_linear = nn.QuantizedLinear.from_linear(
|
||||
fused_linear,
|
||||
linear.group_size,
|
||||
linear.bits,
|
||||
)
|
||||
return fused_linear
|
||||
|
||||
def __init__(
|
||||
@@ -76,22 +84,45 @@ class DoRALinear(nn.Module):
|
||||
)
|
||||
self.lora_b = mx.zeros(shape=(r, output_dims))
|
||||
|
||||
def set_linear(self, linear: nn.Linear):
|
||||
def set_linear(self, linear):
|
||||
"""
|
||||
Set the self.linear layer and recompute self.m.
|
||||
"""
|
||||
self.linear = linear
|
||||
self.m = mx.linalg.norm(self.linear.weight, axis=1)
|
||||
self.m = mx.linalg.norm(self._dequantized_weight().astype(mx.float32), axis=1)
|
||||
|
||||
def _dequantized_weight(self):
|
||||
"""
|
||||
Return the weight of linear layer and dequantize it if is quantized
|
||||
"""
|
||||
weight = self.linear.weight
|
||||
if self._is_quantized():
|
||||
weight = mx.dequantize(
|
||||
weight,
|
||||
self.linear.scales,
|
||||
self.linear.biases,
|
||||
self.linear.group_size,
|
||||
self.linear.bits,
|
||||
)
|
||||
return weight
|
||||
|
||||
def _is_quantized(self):
|
||||
return isinstance(self.linear, nn.QuantizedLinear)
|
||||
|
||||
def __call__(self, x):
|
||||
# Regular LoRA (without a bias)
|
||||
y = x @ self.linear.weight.T
|
||||
w = self._dequantized_weight()
|
||||
y = x @ w.T
|
||||
|
||||
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
|
||||
out = y + (self.scale * z).astype(x.dtype)
|
||||
|
||||
# Compute the norm of the adapted weights
|
||||
adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
|
||||
adapted = w + (self.scale * self.lora_b.T) @ self.lora_a.T
|
||||
denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
|
||||
|
||||
# Remove the norm and scale by the learned magnitude
|
||||
out = (self.m / denom) * out
|
||||
out = (self.m / denom).astype(x.dtype) * out
|
||||
|
||||
if "bias" in self.linear:
|
||||
out = out + self.linear.bias
|
||||
|
Reference in New Issue
Block a user