Support DoRA fine-tuning in mlx-examples/llms/mlx_lm (#779)

* support DoRA fine-tuning

* fix issues in lora.py and tuner/utils.py

* add a use_dora (bool) flag to the functions that load adapters

* remove the unsupported quantization code and fix the calculation issues in mlx_lm/tuner/dora.py

* use stop_gradient to prevent gradients from flowing through the weight norm during backpropagation

* set DEFAULT_USE_DORA in mlx_lm/generate.py

* document the use_dora option everywhere it is used

* support fusing DoRA layers in mlx_lm/fuse.py and fix a bug in to_linear() in mlx_lm/tuner/dora.py

* simplify the check for the type of a fused layer in mlx_lm/fuse.py

* pass use_dora in mlx_lm/fuse.py when calling apply_lora_layers()

* style + nits

* style + nits

* more updates

---------

Co-authored-by: chenyifei08 <chenyifei08@baidu.com>
Co-authored-by: Awni Hannun <awni@apple.com>
Authored by alexC-nonsense4k on 2024-05-16 23:21:26 +08:00, committed by GitHub
parent 69181e0058
commit 42458914c8
7 changed files with 147 additions and 19 deletions

llms/mlx_lm/examples/lora_config.yaml

@@ -51,6 +51,9 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false

+# Use DoRA instead of LoRA.
+use_dora: false
+
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.

llms/mlx_lm/fuse.py

@@ -6,6 +6,7 @@ from pathlib import Path
 from mlx.utils import tree_flatten, tree_unflatten

 from .gguf import convert_to_gguf
+from .tuner.dora import DoRALinear
 from .tuner.lora import LoRALinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (

@@ -18,7 +19,9 @@ from .utils import (
 def parse_arguments() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
+    parser = argparse.ArgumentParser(
+        description="Fuse fine-tuned adapters into the base model."
+    )
     parser.add_argument(
         "--model",
         default="mlx_model",

@@ -79,7 +82,7 @@ def main() -> None:
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, LoRALinear)
+        if isinstance(m, (LoRALinear, DoRALinear))
     ]

     model.update_modules(tree_unflatten(fused_linears))
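
With this change, fuse.py treats DoRA adapters the same way as LoRA adapters: any layer that knows how to fold itself back into a plain linear layer is collected and swapped in via update_modules. A minimal sketch of that pattern on a toy one-layer module (the Toy class and its dimensions are illustrative, not part of the commit):

import mlx.nn as nn
from mlx.utils import tree_unflatten

from mlx_lm.tuner.dora import DoRALinear
from mlx_lm.tuner.lora import LoRALinear


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = DoRALinear(input_dims=16, output_dims=16)

    def __call__(self, x):
        return self.proj(x)


model = Toy()

# Collect every adapter layer and replace it with the fused nn.Linear it produces.
fused = [
    (name, module.to_linear())
    for name, module in model.named_modules()
    if isinstance(module, (LoRALinear, DoRALinear))
]
model.update_modules(tree_unflatten(fused))
print(type(model.proj))  # now a plain nn.Linear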

llms/mlx_lm/generate.py

@@ -123,7 +123,9 @@ def main():
         tokenizer_config["eos_token"] = args.eos_token

     model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
+        args.model,
+        adapter_path=args.adapter_path,
+        tokenizer_config=tokenizer_config,
     )

     if args.use_default_chat_template:

llms/mlx_lm/lora.py

@@ -54,6 +54,7 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "use_dora": False,
 }

@@ -69,6 +70,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",

@@ -117,6 +119,7 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",

@@ -138,8 +141,12 @@ def build_parser():
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
+    )
     return parser
@@ -175,16 +182,20 @@
     adapter_file = adapter_path / "adapters.safetensors"

     if args.test and not args.train:
-        apply_lora_layers(model, adapter_path)
+        # Allow testing without LoRA layers by providing empty path
+        if args.adapter_path != "":
+            apply_lora_layers(model, adapter_path)
     elif args.train:
         adapter_path.mkdir(parents=True, exist_ok=True)
         save_config(vars(args), adapter_path / "adapter_config.json")

         # Convert linear layers to lora layers and unfreeze in the process
-        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+        linear_to_lora_layers(
+            model, args.lora_layers, args.lora_parameters, args.use_dora
+        )

         print_trainable_parameters(model)
     else:
         raise ValueError("Must provide at least one of --train or --test")

     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
@@ -257,12 +268,12 @@ def main():
             config = yaml.load(file, yaml_loader)
         # Prefer parameters from command-line arguments
         for k, v in config.items():
-            if not args.get(k, None):
+            if args.get(k, None) is None:
                 args[k] = v

     # Update defaults for unspecified parameters
     for k, v in CONFIG_DEFAULTS.items():
-        if not args.get(k, None):
+        if args.get(k, None) is None:
             args[k] = v

     run(types.SimpleNamespace(**args))
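
The default=None values added to the boolean flags above work together with this merge logic: an option the user did not pass is still None, so it can be filled in first from the YAML config and then from CONFIG_DEFAULTS, while anything passed explicitly on the command line wins. A toy illustration of that ordering (the dictionaries are stand-ins for the parsed arguments, not part of the commit):

cli_args = {"train": True, "use_dora": None, "lora_layers": None}
config = {"use_dora": True, "lora_layers": 8}        # from the YAML file
defaults = {"train": None, "use_dora": False, "lora_layers": 16, "seed": 0}

# Config fills anything the CLI left unset ...
for k, v in config.items():
    if cli_args.get(k, None) is None:
        cli_args[k] = v

# ... and defaults fill whatever is still unset.
for k, v in defaults.items():
    if cli_args.get(k, None) is None:
        cli_args[k] = v

print(cli_args)  # {'train': True, 'use_dora': True, 'lora_layers': 8, 'seed': 0}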

llms/mlx_lm/tuner/dora.py (new file, 98 lines)

@@ -0,0 +1,98 @@
# Copyright © 2024 Apple Inc.

import math

import mlx.core as mx
import mlx.nn as nn


class DoRALinear(nn.Module):
    @staticmethod
    def from_linear(
        linear: nn.Linear,
        r: int = 8,
        alpha: float = 16,
        dropout: float = 0.0,
        scale: float = 10.0,
    ):
        # TODO support quantized weights in DoRALinear
        output_dims, input_dims = linear.weight.shape
        if isinstance(linear, nn.QuantizedLinear):
            raise ValueError("DoRALinear does not yet support quantization.")
        dora_lin = DoRALinear(
            input_dims=input_dims,
            output_dims=output_dims,
            r=r,
            alpha=alpha,
            dropout=dropout,
            scale=scale,
        )
        dora_lin.linear = linear
        return dora_lin

    def to_linear(self, de_quantize: bool = False):
        linear = self.linear
        bias = "bias" in linear
        weight = linear.weight

        # Use the same type as the linear weight if not quantized
        dtype = weight.dtype

        output_dims, input_dims = weight.shape
        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)

        lora_b = (self.scale * self.lora_b.T).astype(dtype)
        lora_a = self.lora_a.T.astype(dtype)
        weight = weight + lora_b @ lora_a
        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
        fused_linear.weight = norm_scale[:, None] * weight

        if bias:
            fused_linear.bias = linear.bias
        return fused_linear

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        r: int = 8,
        alpha: float = 16,
        dropout: float = 0.0,
        scale: float = 10.0,
        bias: bool = False,
    ):
        super().__init__()

        # Regular linear layer weights
        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
        self.dropout = nn.Dropout(p=dropout)

        # Scale for low-rank update
        self.scale = scale * (alpha / r)

        # Low rank lora weights
        scale = 1 / math.sqrt(input_dims)
        self.lora_a = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(input_dims, r),
        )
        self.lora_b = mx.zeros(shape=(r, output_dims))
        self.m = mx.linalg.norm(self.linear.weight, axis=1)

    def __call__(self, x):
        # Regular LoRA (without a bias)
        y = x @ self.linear.weight.T
        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
        out = y + (self.scale * z).astype(x.dtype)

        # Compute the norm of the adapted weights
        adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))

        # Remove the norm and scale by the learned magnitude
        out = (self.m / denom) * out

        if "bias" in self.linear:
            out = out + self.linear.bias
        return out
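
For orientation, a small usage sketch (not part of the commit): DoRALinear keeps the frozen base linear layer, a low-rank update (lora_a, lora_b), and a learned per-output-row magnitude m, and to_linear() folds all three back into a single nn.Linear. The layer sizes below are arbitrary.

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.dora import DoRALinear

base = nn.Linear(64, 32)                    # stands in for a frozen model weight
dora = DoRALinear.from_linear(base, r=8, alpha=16)

x = mx.random.normal((2, 64))
y_adapted = dora(x)                         # base output + scaled low-rank update,
                                            # rescaled by the learned magnitude m

fused = dora.to_linear()                    # merge everything into one nn.Linear
y_fused = fused(x)
print(y_adapted.shape, y_fused.shape)       # (2, 32) (2, 32)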

llms/mlx_lm/tuner/utils.py

@@ -9,6 +9,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten

+from .dora import DoRALinear
 from .lora import LoRALinear

@@ -36,6 +37,7 @@ def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
     config: Dict,
+    use_dora: bool = False,
 ):
     """
     Convert some of the models linear layers to lora layers.

@@ -46,6 +48,8 @@ def linear_to_lora_layers(
             starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
             rank, alpha, scale, and optional layer keys.
+        use_dora (bool): If True, uses DoRA instead of LoRA.
+            Default: ``False``
     """
     num_layers = len(model.layers)

@@ -54,8 +58,10 @@ def linear_to_lora_layers(
             f"Requested {num_lora_layers} LoRA layers "
             f"but the model only has {num_layers} layers."
         )

-    to_lora = lambda lin: LoRALinear.from_linear(
+    cls = DoRALinear if use_dora else LoRALinear
+    def to_lora(lin):
+        return cls.from_linear(
             lin,
             r=config["rank"],
             alpha=config["alpha"],

@@ -119,7 +125,12 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    linear_to_lora_layers(
+        model,
+        config.lora_layers,
+        config.lora_parameters,
+        getattr(config, "use_dora", False),
+    )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
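
Callers can also request DoRA directly through linear_to_lora_layers. A hedged sketch (the model path is a placeholder, and the weights must be unquantized since DoRALinear rejects nn.QuantizedLinear):

from mlx_lm import load
from mlx_lm.tuner.utils import linear_to_lora_layers

model, tokenizer = load("path/to/unquantized-model")  # placeholder path
model.freeze()  # keep the base weights fixed; only adapter parameters will train

lora_parameters = {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}
linear_to_lora_layers(model, 16, lora_parameters, use_dora=True)  # convert the last 16 layers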

llms/mlx_lm/version.py

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.13.0"
+__version__ = "0.13.1"