mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-01 04:14:38 +08:00
Adding full finetuning (#903)
* Adding full model weights finetuning * Updating the LORA.md and ACKNOWLEDGMENTS.md files. * removing --use-dora and --fulll-training and adding --fine-tune-type * some clean up * reformating and fixing dora training * updated CONFIG_DEFAULTS * update config example * update in the config example fie * Update LORA.md * merge and commit * adding argument for dora linear layer * clean up * clean up in the example yaml file * fix * final fix before sending * small addition to re md file * fix for loading the fully trained model by saving all the files and configs correctly * clean up * removing the unnesesairy files * changing lora layers back to 16 * removed max file size * nits * resolve merge * some consistency changes --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -15,9 +15,9 @@ from .tokenizer_utils import TokenizerWrapper
|
||||
from .tuner.datasets import load_dataset
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.utils import (
|
||||
apply_lora_layers,
|
||||
build_schedule,
|
||||
linear_to_lora_layers,
|
||||
load_adapters,
|
||||
print_trainable_parameters,
|
||||
)
|
||||
from .utils import load, save_config
|
||||
@@ -41,9 +41,10 @@ yaml_loader.add_implicit_resolver(
|
||||
CONFIG_DEFAULTS = {
|
||||
"model": "mlx_model",
|
||||
"train": False,
|
||||
"fine_tune_type": "lora",
|
||||
"data": "data/",
|
||||
"seed": 0,
|
||||
"lora_layers": 16,
|
||||
"num_layers": 16,
|
||||
"batch_size": 4,
|
||||
"iters": 1000,
|
||||
"val_batches": 25,
|
||||
@@ -58,7 +59,6 @@ CONFIG_DEFAULTS = {
|
||||
"max_seq_length": 2048,
|
||||
"lr_schedule": None,
|
||||
"lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
|
||||
"use_dora": False,
|
||||
}
|
||||
|
||||
|
||||
@@ -82,7 +82,14 @@ def build_parser():
|
||||
help="Directory with {train, valid, test}.jsonl files",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--lora-layers",
|
||||
"--fine-tune-type",
|
||||
type=str,
|
||||
choices=["lora", "dora", "full"],
|
||||
default="lora",
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-layers",
|
||||
type=int,
|
||||
help="Number of layers to fine-tune. Default is 16, use -1 for all.",
|
||||
)
|
||||
@@ -107,12 +114,12 @@ def build_parser():
|
||||
parser.add_argument(
|
||||
"--resume-adapter-file",
|
||||
type=str,
|
||||
help="Load path to resume training with the given adapters.",
|
||||
help="Load path to resume training from the given fine-tuned weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--adapter-path",
|
||||
type=str,
|
||||
help="Save/load path for the adapters.",
|
||||
help="Save/load path for the fine-tuned weights.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-every",
|
||||
@@ -148,9 +155,6 @@ def build_parser():
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--seed", type=int, default=None, help="The PRNG seed")
|
||||
parser.add_argument(
|
||||
"--use-dora", action="store_true", default=None, help="Use DoRA to finetune."
|
||||
)
|
||||
return parser
|
||||
|
||||
|
||||
@@ -162,21 +166,31 @@ def train_model(
|
||||
valid_set,
|
||||
training_callback: TrainingCallback = None,
|
||||
):
|
||||
# Freeze all layers
|
||||
model.freeze()
|
||||
if args.fine_tune_type == "full":
|
||||
for l in model.layers[-min(args.num_layers, 0) :]:
|
||||
l.unfreeze()
|
||||
elif args.fine_tune_type in ["lora", "dora"]:
|
||||
# Convert linear layers to lora/dora layers and unfreeze in the process
|
||||
linear_to_lora_layers(
|
||||
model,
|
||||
args.num_layers,
|
||||
args.lora_parameters,
|
||||
use_dora=(args.fine_tune_type == "dora"),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")
|
||||
|
||||
# Convert linear layers to lora layers and unfreeze in the process
|
||||
linear_to_lora_layers(model, args.lora_layers, args.lora_parameters, args.use_dora)
|
||||
|
||||
# Resume training the given adapters.
|
||||
# Resume from weights if provided
|
||||
if args.resume_adapter_file is not None:
|
||||
print(f"Loading pretrained adapters from {args.resume_adapter_file}")
|
||||
print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
|
||||
model.load_weights(args.resume_adapter_file, strict=False)
|
||||
|
||||
print_trainable_parameters(model)
|
||||
|
||||
adapter_path = Path(args.adapter_path)
|
||||
adapter_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
adapter_file = adapter_path / "adapters.safetensors"
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
|
||||
@@ -240,7 +254,7 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
if args.test and not args.train:
|
||||
# Allow testing without LoRA layers by providing empty path
|
||||
if args.adapter_path != "":
|
||||
apply_lora_layers(model, args.adapter_path)
|
||||
load_adapters(model, args.adapter_path)
|
||||
|
||||
elif args.train:
|
||||
print("Training")
|
||||
|
Reference in New Issue
Block a user