Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-09-01 04:14:38 +08:00)
support dora finetune in mlx-examples/llms/mlx_lm (#779)
* support dora finetune
* solve problems in lora.py and tuner.utils.py
* add use_dora (bool) in functions of load adapters
* delete all unsupported quantization code and fix all the calculation problems in mlx_lm/tuner/dora.py
* use stop_gradient to prevent gradients from flowing through `norm` during backpropagation
* set DEFAULT_USE_DORA in mlx_lm/generate.py
* add annotation for all the use_dora
* mlx_lm/fuse.py: support fusing dora layers, and fix a bug in to_linear() in mlx_lm/tuner/dora.py
* simplify the code for judging the type of a fused layer in mlx_lm/fuse.py
* add use_dora in mlx_lm/fuse.py when calling apply_lora_layers()
* style + nits
* style + nits
* more updates

---------

Co-authored-by: chenyifei08 <chenyifei08@baidu.com>
Co-authored-by: Awni Hannun <awni@apple.com>
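For context (not part of the commit message): DoRA decomposes each adapted weight into a learned per-row magnitude and a direction. In the notation of dora.py below (m = self.m, W_0 the frozen base weight, B = lora_b.T, A = lora_a.T, s = self.scale), the effective weight works out to roughly

    W' = m \odot \frac{W_0 + s\,B A}{\lVert W_0 + s\,B A \rVert_{2,\text{row}}}

The stop_gradient mentioned above treats the row norm in the denominator as a constant during backpropagation, mirroring the approximation suggested in the DoRA paper to reduce training overhead.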
llms/mlx_lm/tuner/dora.py | 98 (new file)
@@ -0,0 +1,98 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class DoRALinear(nn.Module):
+    @staticmethod
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+    ):
+        # TODO support quantized weights in DoRALinear
+        output_dims, input_dims = linear.weight.shape
+        if isinstance(linear, nn.QuantizedLinear):
+            raise ValueError("DoRALinear does not yet support quantization.")
+        dora_lin = DoRALinear(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            alpha=alpha,
+            dropout=dropout,
+            scale=scale,
+        )
+        dora_lin.linear = linear
+        return dora_lin
+
+    def to_linear(self, de_quantize: bool = False):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+
+        # Use the same type as the linear weight if not quantized
+        dtype = weight.dtype
+
+        output_dims, input_dims = weight.shape
+        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        lora_b = (self.scale * self.lora_b.T).astype(dtype)
+        lora_a = self.lora_a.T.astype(dtype)
+        weight = weight + lora_b @ lora_a
+        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+        fused_linear.weight = norm_scale[:, None] * weight
+
+        if bias:
+            fused_linear.bias = linear.bias
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale * (alpha / r)
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, output_dims))
+        self.m = mx.linalg.norm(self.linear.weight, axis=1)
+
+    def __call__(self, x):
+        # Regular LoRA (without a bias)
+        y = x @ self.linear.weight.T
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+        out = y + (self.scale * z).astype(x.dtype)
+
+        # Compute the norm of the adapted weights
+        adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m / denom) * out
+
+        if "bias" in self.linear:
+            out = out + self.linear.bias
+        return out
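A minimal sketch (not part of the diff) exercising the class above: it wraps a plain nn.Linear, gives the low-rank update a nonzero value, and checks that the fused layer returned by to_linear() agrees with __call__. It assumes mlx and mlx_lm are installed so that mlx_lm.tuner.dora is importable.

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.dora import DoRALinear

# Wrap a plain linear layer in a DoRA adapter.
base = nn.Linear(16, 32, bias=True)
dora = DoRALinear.from_linear(base, r=8, alpha=16, dropout=0.0, scale=10.0)

# Give the low-rank update a nonzero value so the comparison is not trivial
# (lora_b starts at zero).
dora.lora_b = 1e-2 * mx.random.normal(shape=dora.lora_b.shape)

x = mx.random.normal(shape=(4, 16))

# Forward through the adapter, then through the fused plain nn.Linear.
y_adapted = dora(x)
y_fused = dora.to_linear()(x)

print(mx.allclose(y_adapted, y_fused, rtol=1e-4, atol=1e-5))  # should print True

This is the same identity that mlx_lm/fuse.py relies on when it folds trained DoRA adapters back into the base weights.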
llms/mlx_lm/tuner/utils.py

@@ -9,6 +9,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
+from .dora import DoRALinear
 from .lora import LoRALinear
 
 
@@ -36,6 +37,7 @@ def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
     config: Dict,
+    use_dora: bool = False,
 ):
     """
     Convert some of the models linear layers to lora layers.
@@ -46,6 +48,8 @@ def linear_to_lora_layers(
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
             rank, alpha, scale, and optional layer keys.
+        use_dora (bool): If True, uses DoRA instead of LoRA.
+            Default: ``False``
     """
 
     num_layers = len(model.layers)
@@ -54,14 +58,16 @@ def linear_to_lora_layers(
             f"Requested {num_lora_layers} LoRA layers "
             f"but the model only has {num_layers} layers."
         )
+    cls = DoRALinear if use_dora else LoRALinear
 
-    to_lora = lambda lin: LoRALinear.from_linear(
-        lin,
-        r=config["rank"],
-        alpha=config["alpha"],
-        scale=config["scale"],
-        dropout=config["dropout"],
-    )
+    def to_lora(lin):
+        return cls.from_linear(
+            lin,
+            r=config["rank"],
+            alpha=config["alpha"],
+            scale=config["scale"],
+            dropout=config["dropout"],
+        )
 
     keys = config.get("keys", None)
     if keys is not None:
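A sketch of driving the updated function directly (not part of the diff; "path/to/model" is a placeholder, and the config keys are exactly the ones read above). mlx_lm's lora.py normally does this for you, freezing the base model before converting layers.

from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers

model, tokenizer = load("path/to/model")  # placeholder path

# Freeze the base weights; only the adapter parameters stay trainable.
model.freeze()

lora_config = {"rank": 8, "alpha": 16, "scale": 10.0, "dropout": 0.0}

# With use_dora=True the new DoRALinear class is used instead of LoRALinear.
linear_to_lora_layers(model, num_lora_layers=8, config=lora_config, use_dora=True)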
@@ -119,7 +125,12 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    linear_to_lora_layers(
+        model,
+        config.lora_layers,
+        config.lora_parameters,
+        getattr(config, "use_dora", False),
+    )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
 
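And a sketch of the reload path (paths are placeholders): apply_lora_layers reads adapter_config.json and, with the change above, falls back to plain LoRA when the saved config predates the use_dora field.

from mlx_lm import load
from mlx_lm.tuner.utils import apply_lora_layers

# Load the base model, then attach the trained (Do)LoRA adapters from disk.
model, tokenizer = load("path/to/model")  # placeholder path
model = apply_lora_layers(model, "path/to/adapters")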