Mirror of https://github.com/ml-explore/mlx-examples.git
Synced 2025-06-24 17:31:18 +08:00
support dora finetune in mlx-examples/llms/mlx_lm (#779)
* Support DoRA finetuning
* Fix issues in lora.py and tuner/utils.py
* Add a use_dora (bool) argument to the adapter-loading functions
* Remove unsupported quantization code and fix the calculation issues in mlx_lm/tuner/dora.py
* Use stop_gradient to prevent gradients from flowing through the norm during backpropagation
* Set DEFAULT_USE_DORA in mlx_lm/generate.py
* Add comments for all use_dora usages
* Support fusing DoRA layers in mlx_lm/fuse.py and fix a to_linear() bug in mlx_lm/tuner/dora.py
* Simplify the type check for a fused layer in mlx_lm/fuse.py
* Pass use_dora in mlx_lm/fuse.py when calling apply_lora_layers()
* Style + nits
* Style + nits
* More updates

---------

Co-authored-by: chenyifei08 <chenyifei08@baidu.com>
Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent 69181e0058
commit 42458914c8
Example LoRA configuration (YAML):

@@ -51,6 +51,9 @@ max_seq_length: 2048
 # Use gradient checkpointing to reduce memory use.
 grad_checkpoint: false
 
+# Use DoRA instead of LoRA.
+use_dora: false
+
 # LoRA parameters can only be specified in a config file
 lora_parameters:
   # The layer keys to apply LoRA to.
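The new use_dora key defaults to false and mirrors the CONFIG_DEFAULTS entry added further down. As a minimal sketch (the config file name is an assumption and PyYAML is required), reading the flag from such a config looks like:

    import yaml

    # Hypothetical path to an example config like the one shown above.
    with open("lora_config.yaml") as f:
        cfg = yaml.safe_load(f)

    # When true, the tuner swaps DoRALinear in for LoRALinear
    # (see the linear_to_lora_layers change below).
    print(cfg.get("use_dora", False))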
llms/mlx_lm/fuse.py:

@@ -6,6 +6,7 @@ from pathlib import Path
 from mlx.utils import tree_flatten, tree_unflatten
 
 from .gguf import convert_to_gguf
+from .tuner.dora import DoRALinear
 from .tuner.lora import LoRALinear
 from .tuner.utils import apply_lora_layers, dequantize
 from .utils import (
@@ -18,7 +19,9 @@ from .utils import (
 
 
 def parse_arguments() -> argparse.Namespace:
-    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
+    parser = argparse.ArgumentParser(
+        description="Fuse fine-tuned adapters into the base model."
+    )
     parser.add_argument(
         "--model",
         default="mlx_model",
@@ -79,7 +82,7 @@ def main() -> None:
     fused_linears = [
         (n, m.to_linear())
         for n, m in model.named_modules()
-        if isinstance(m, LoRALinear)
+        if isinstance(m, (LoRALinear, DoRALinear))
     ]
 
     model.update_modules(tree_unflatten(fused_linears))
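A small end-to-end sketch of the fusing step above, assuming the mlx_lm package from this repo is installed; the toy module and its sizes are illustrative only, not part of the commit:

    import mlx.nn as nn
    from mlx.utils import tree_unflatten

    from mlx_lm.tuner.dora import DoRALinear


    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(16, 16, bias=False)

        def __call__(self, x):
            return self.proj(x)


    model = Toy()
    # Wrap the projection in a DoRA adapter, then fold it back into a plain
    # nn.Linear exactly as main() does for every adapter layer.
    model.proj = DoRALinear.from_linear(model.proj)
    fused = [
        (n, m.to_linear())
        for n, m in model.named_modules()
        if isinstance(m, DoRALinear)
    ]
    model.update_modules(tree_unflatten(fused))
    print(isinstance(model.proj, nn.Linear))  # True: only plain layers remain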
llms/mlx_lm/generate.py:

@@ -123,7 +123,9 @@ def main():
         tokenizer_config["eos_token"] = args.eos_token
 
     model, tokenizer = load(
-        args.model, adapter_path=args.adapter_path, tokenizer_config=tokenizer_config
+        args.model,
+        adapter_path=args.adapter_path,
+        tokenizer_config=tokenizer_config,
     )
 
     if args.use_default_chat_template:
llms/mlx_lm/lora.py:

@@ -54,6 +54,7 @@ CONFIG_DEFAULTS = {
     "max_seq_length": 2048,
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
+    "use_dora": False,
 }
 
 
@@ -69,6 +70,7 @@ def build_parser():
         "--train",
         action="store_true",
         help="Do training",
+        default=None,
     )
     parser.add_argument(
         "--data",
@@ -117,6 +119,7 @@ def build_parser():
         "--test",
         action="store_true",
         help="Evaluate on the test set after training",
+        default=None,
     )
     parser.add_argument(
         "--test-batches",
@@ -138,8 +141,12 @@ def build_parser():
         "--grad-checkpoint",
         action="store_true",
         help="Use gradient checkpointing to reduce memory use.",
+        default=None,
     )
     parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
+    parser.add_argument(
+        "--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
+    )
     return parser
 
 
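The default=None added to these store_true flags (and used for --use-dora) lets the config handling in main() distinguish a flag the user did not pass (None) from one they did pass (True); the merge with the YAML config and CONFIG_DEFAULTS keys off that distinction. A minimal standalone sketch of the behavior:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--use-dora", action="store_true", default=None)

    print(vars(parser.parse_args([])))              # {'use_dora': None} -> not given
    print(vars(parser.parse_args(["--use-dora"])))  # {'use_dora': True}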
@@ -175,16 +182,20 @@ def run(args, training_callback: TrainingCallback = None):
     adapter_file = adapter_path / "adapters.safetensors"
 
     if args.test and not args.train:
-        apply_lora_layers(model, adapter_path)
-    else:
+        # Allow testing without LoRA layers by providing empty path
+        if args.adapter_path != "":
+            apply_lora_layers(model, adapter_path)
+    elif args.train:
         adapter_path.mkdir(parents=True, exist_ok=True)
         save_config(vars(args), adapter_path / "adapter_config.json")
 
         # Convert linear layers to lora layers and unfreeze in the process
-        linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)
+        linear_to_lora_layers(
+            model, args.lora_layers, args.lora_parameters, args.use_dora
+        )
         print_trainable_parameters(model)
+    else:
+        raise ValueError("Must provide at least one of --train or --test")
 
     print("Loading datasets")
     train_set, valid_set, test_set = load_dataset(args, tokenizer)
@@ -257,12 +268,12 @@ def main():
             config = yaml.load(file, yaml_loader)
         # Prefer parameters from command-line arguments
         for k, v in config.items():
-            if not args.get(k, None):
+            if args.get(k, None) is not None:
                 args[k] = v
 
     # Update defaults for unspecified parameters
     for k, v in CONFIG_DEFAULTS.items():
-        if not args.get(k, None):
+        if args.get(k, None) is None:
             args[k] = v
     run(types.SimpleNamespace(**args))
 
llms/mlx_lm/tuner/dora.py (new file, 98 lines):

@@ -0,0 +1,98 @@
+# Copyright © 2024 Apple Inc.
+
+import math
+
+import mlx.core as mx
+import mlx.nn as nn
+
+
+class DoRALinear(nn.Module):
+    @staticmethod
+    def from_linear(
+        linear: nn.Linear,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+    ):
+        # TODO support quantized weights in DoRALinear
+        output_dims, input_dims = linear.weight.shape
+        if isinstance(linear, nn.QuantizedLinear):
+            raise ValueError("DoRALinear does not yet support quantization.")
+        dora_lin = DoRALinear(
+            input_dims=input_dims,
+            output_dims=output_dims,
+            r=r,
+            alpha=alpha,
+            dropout=dropout,
+            scale=scale,
+        )
+        dora_lin.linear = linear
+        return dora_lin
+
+    def to_linear(self, de_quantize: bool = False):
+        linear = self.linear
+        bias = "bias" in linear
+        weight = linear.weight
+
+        # Use the same type as the linear weight if not quantized
+        dtype = weight.dtype
+
+        output_dims, input_dims = weight.shape
+        fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
+
+        lora_b = (self.scale * self.lora_b.T).astype(dtype)
+        lora_a = self.lora_a.T.astype(dtype)
+        weight = weight + lora_b @ lora_a
+        norm_scale = self.m / mx.linalg.norm(weight, axis=1)
+        fused_linear.weight = norm_scale[:, None] * weight
+
+        if bias:
+            fused_linear.bias = linear.bias
+        return fused_linear
+
+    def __init__(
+        self,
+        input_dims: int,
+        output_dims: int,
+        r: int = 8,
+        alpha: float = 16,
+        dropout: float = 0.0,
+        scale: float = 10.0,
+        bias: bool = False,
+    ):
+        super().__init__()
+
+        # Regular linear layer weights
+        self.linear = nn.Linear(input_dims, output_dims, bias=bias)
+        self.dropout = nn.Dropout(p=dropout)
+
+        # Scale for low-rank update
+        self.scale = scale * (alpha / r)
+
+        # Low rank lora weights
+        scale = 1 / math.sqrt(input_dims)
+        self.lora_a = mx.random.uniform(
+            low=-scale,
+            high=scale,
+            shape=(input_dims, r),
+        )
+        self.lora_b = mx.zeros(shape=(r, output_dims))
+        self.m = mx.linalg.norm(self.linear.weight, axis=1)
+
+    def __call__(self, x):
+        # Regular LoRA (without a bias)
+        y = x @ self.linear.weight.T
+        z = (self.dropout(x) @ self.lora_a) @ self.lora_b
+        out = y + (self.scale * z).astype(x.dtype)
+
+        # Compute the norm of the adapted weights
+        adapted = self.linear.weight + (self.scale * self.lora_b.T) @ self.lora_a.T
+        denom = mx.stop_gradient(mx.linalg.norm(adapted, axis=1))
+
+        # Remove the norm and scale by the learned magnitude
+        out = (self.m / denom) * out
+
+        if "bias" in self.linear:
+            out = out + self.linear.bias
+        return out
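A short usage sketch of the new layer; the dimensions, tolerance, and import path assume the mlx_lm package from this repo is installed and are not part of the commit itself:

    import mlx.core as mx
    import mlx.nn as nn

    from mlx_lm.tuner.dora import DoRALinear

    # Wrap an existing projection in a DoRA adapter.
    base = nn.Linear(512, 1024, bias=False)
    dora = DoRALinear.from_linear(base, r=8, alpha=16, dropout=0.0, scale=10.0)

    x = mx.random.normal((2, 512))
    y = dora(x)      # LoRA update, rescaled by the learned magnitude m over the
                     # stop-gradient norm of the adapted weight
    print(y.shape)   # (2, 1024)

    # Fold the low-rank update and the magnitude back into a plain nn.Linear,
    # which is what fuse.py does for each adapter layer.
    fused = dora.to_linear()
    print(mx.allclose(fused(x), y, atol=1e-4))  # True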
llms/mlx_lm/tuner/utils.py:

@@ -9,6 +9,7 @@ import mlx.nn as nn
 import mlx.optimizers as opt
 from mlx.utils import tree_unflatten
 
+from .dora import DoRALinear
 from .lora import LoRALinear
 
 
@@ -36,6 +37,7 @@ def linear_to_lora_layers(
     model: nn.Module,
     num_lora_layers: int,
     config: Dict,
+    use_dora: bool = False,
 ):
     """
     Convert some of the models linear layers to lora layers.
@@ -46,6 +48,8 @@ def linear_to_lora_layers(
           starting from the last layer.
         config (dict): More configuration parameters for LoRA, including the
           rank, alpha, scale, and optional layer keys.
+        use_dora (bool): If True, uses DoRA instead of LoRA.
+          Default: ``False``
     """
 
     num_layers = len(model.layers)
@@ -54,8 +58,10 @@ def linear_to_lora_layers(
             f"Requested {num_lora_layers} LoRA layers "
             f"but the model only has {num_layers} layers."
         )
+    cls = DoRALinear if use_dora else LoRALinear
 
-    to_lora = lambda lin: LoRALinear.from_linear(
+    def to_lora(lin):
+        return cls.from_linear(
             lin,
             r=config["rank"],
             alpha=config["alpha"],
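A hypothetical helper mirroring the pattern above (the function name and keys argument are illustrative, not from the commit; params follows the lora_parameters keys): pick the adapter class once from use_dora, then convert selected projections in a transformer block:

    import mlx.nn as nn
    from mlx.utils import tree_unflatten

    from mlx_lm.tuner.dora import DoRALinear
    from mlx_lm.tuner.lora import LoRALinear


    def convert_block(block: nn.Module, keys, params: dict, use_dora: bool = False):
        # Choose the adapter type once, as in the hunk above.
        cls = DoRALinear if use_dora else LoRALinear

        def to_lora(lin):
            return cls.from_linear(
                lin,
                r=params["rank"],
                alpha=params["alpha"],
                dropout=params["dropout"],
                scale=params["scale"],
            )

        # Replace only the projections named in `keys` (e.g. attention projections).
        replaced = [(k, to_lora(m)) for k, m in block.named_modules() if k in keys]
        block.update_modules(tree_unflatten(replaced))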
@@ -119,7 +125,12 @@ def apply_lora_layers(model: nn.Module, adapter_path: str) -> nn.Module:
         raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}")
     with open(adapter_path / "adapter_config.json", "r") as fid:
         config = types.SimpleNamespace(**json.load(fid))
-    linear_to_lora_layers(model, config.lora_layers, config.lora_parameters)
+    linear_to_lora_layers(
+        model,
+        config.lora_layers,
+        config.lora_parameters,
+        getattr(config, "use_dora", False),
+    )
     model.load_weights(str(adapter_path / "adapters.safetensors"), strict=False)
     return model
 
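The getattr fallback above keeps older adapters loadable: an adapter_config.json written before this change has no use_dora entry, so loading falls back to plain LoRA instead of raising AttributeError. A minimal sketch (the path is an assumption):

    import json
    import types

    with open("adapters/adapter_config.json") as fid:
        config = types.SimpleNamespace(**json.load(fid))

    # Missing in configs written by older versions -> fall back to LoRA.
    use_dora = getattr(config, "use_dora", False)
    print(use_dora)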
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.13.0"
+__version__ = "0.13.1"