# Copyright © 2024 Apple Inc.

import argparse
import math
import os
import re
import types
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml

from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
    load_adapters,
    print_trainable_parameters,
)
from .utils import load, save_config

# YAML loader used to read ``--config`` files in ``main``.
#
# The extra implicit float resolver accepts forms such as ``1e-5``: an
# exponent with no "." and/or an unsigned exponent. PyYAML's stock float
# resolver does not recognize these, so values like ``learning_rate: 1e-5``
# would otherwise load as strings.
#
# NOTE(review): ``yaml_loader`` is ``yaml.SafeLoader`` itself rather than a
# subclass, so importing this module registers the resolver process-wide for
# every ``yaml.safe_load`` caller. Kept as-is to avoid changing a module-level
# side effect other code may depend on.
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

# Fallback values that ``main`` applies to every option still unset (None)
# after the command line and the optional YAML config file are merged.
CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "fine_tune_type": "lora",
    "data": "data/",
    "seed": 0,
    "num_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_path": "adapters",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "config": None,
    "grad_checkpoint": False,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
    "mask_prompt": False,
}


def build_parser():
    """Build the argument parser for the LoRA / QLoRA fine-tuning CLI.

    No option has a non-None default: valued options have no ``default`` and
    ``store_true`` flags use ``default=None``. This lets ``main`` tell "not
    given" apart from an explicit value and fill the gaps from the config file
    or ``CONFIG_DEFAULTS``.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        type=str,
        help="The path to the local model directory or Hugging Face repo.",
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
        default=None,
    )
    parser.add_argument(
        "--data",
        type=str,
        help=(
            "Directory with {train, valid, test}.jsonl files or the name "
            "of a Hugging Face dataset (e.g., 'mlx-community/wikisql')"
        ),
    )
    parser.add_argument(
        "--fine-tune-type",
        type=str,
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=None,
    )
    parser.add_argument(
        "--num-layers",
        type=int,
        help="Number of layers to fine-tune. Default is 16, use -1 for all.",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training from the given fine-tuned weights.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the fine-tuned weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
        default=None,
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
        default=None,
    )
    parser.add_argument("--seed", type=int, help="The PRNG seed")
    return parser


def train_model(
    args,
    model: nn.Module,
    tokenizer: TokenizerWrapper,
    train_set,
    valid_set,
    training_callback: Optional[TrainingCallback] = None,
):
    """Prepare ``model`` for fine-tuning and run the training loop.

    Freezes the whole model, then makes trainable either the last
    ``args.num_layers`` layers ("full") or LoRA/DoRA adapters on them
    ("lora"/"dora"). Optionally loads resume weights, writes
    ``adapter_config.json`` into ``args.adapter_path``, and calls ``train``
    with an Adam optimizer.

    Args:
        args: Namespace holding the merged CLI / config / default options.
        model: Model to fine-tune; must expose a ``layers`` sequence.
        tokenizer: Tokenizer passed through to ``train``.
        train_set: Training dataset from ``load_dataset``.
        valid_set: Validation dataset from ``load_dataset``.
        training_callback: Optional callback forwarded to ``train``.

    Raises:
        ValueError: if ``args.num_layers`` exceeds the model's layer count or
            ``args.fine_tune_type`` is not one of "full", "lora", "dora".
    """
    model.freeze()
    if args.num_layers > len(model.layers):
        raise ValueError(
            f"Requested to train {args.num_layers} layers "
            f"but the model only has {len(model.layers)} layers."
        )

    if args.fine_tune_type == "full":
        # num_layers == -1 gives max(-1, 0) == 0, and the slice [-0:] == [0:]
        # selects every layer.
        # NOTE(review): num_layers == 0 takes the same path and also
        # unfreezes every layer; confirm that is intended.
        for layer in model.layers[-max(args.num_layers, 0) :]:
            layer.unfreeze()
    elif args.fine_tune_type in ["lora", "dora"]:
        # Convert linear layers to lora/dora layers and unfreeze in the process
        linear_to_lora_layers(
            model,
            args.num_layers,
            args.lora_parameters,
            use_dora=(args.fine_tune_type == "dora"),
        )
    else:
        raise ValueError(f"Received unknown fine-tune-type {args.fine_tune_type}")

    # Resume from weights if provided. strict=False because the file may only
    # cover a subset of the model's parameters (e.g. adapter weights).
    if args.resume_adapter_file is not None:
        print(f"Loading fine-tuned weights from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    print_trainable_parameters(model)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)
    adapter_file = adapter_path / "adapters.safetensors"
    # Persist the full run configuration next to the adapter weights.
    save_config(vars(args), adapter_path / "adapter_config.json")

    training_args = TrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=adapter_file,
        max_seq_length=args.max_seq_length,
        grad_checkpoint=args.grad_checkpoint,
    )

    model.train()
    # A configured lr_schedule takes precedence over the constant learning rate.
    opt = optim.Adam(
        learning_rate=(
            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
        )
    )

    train(
        model=model,
        tokenizer=tokenizer,
        args=training_args,
        optimizer=opt,
        train_dataset=train_set,
        val_dataset=valid_set,
        training_callback=training_callback,
    )


def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
    """Evaluate ``model`` on ``test_set`` and print the loss and perplexity.

    Perplexity is reported as ``exp(test_loss)``. Batch size, number of
    batches (``args.test_batches``) and maximum sequence length come from
    ``args``.
    """
    model.eval()

    test_loss = evaluate(
        model=model,
        dataset=test_set,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.test_batches,
        max_seq_length=args.max_seq_length,
    )

    test_ppl = math.exp(test_loss)

    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def run(args, training_callback: Optional[TrainingCallback] = None):
    """Load the model and datasets, then train and/or test as requested.

    Args:
        args: Namespace of fully-resolved options (see ``CONFIG_DEFAULTS``).
        training_callback: Optional callback forwarded to ``train_model``.

    Raises:
        ValueError: if neither ``args.train`` nor ``args.test`` is set.
    """
    # Seed both NumPy and MLX's global PRNG so --seed governs every random
    # source this process may draw from. The NumPy seed was always set; the
    # MLX seed is added so that anything using mx.random (presumably adapter
    # initialization / dropout) is also reproducible.
    np.random.seed(args.seed)
    mx.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    if args.test and not args.train:
        # Allow testing without LoRA layers by providing empty path
        if args.adapter_path != "":
            load_adapters(model, args.adapter_path)

    elif args.train:
        print("Training")
        train_model(args, model, tokenizer, train_set, valid_set, training_callback)
    else:
        raise ValueError("Must provide at least one of --train or --test")

    if args.test:
        print("Testing")
        evaluate_model(args, model, tokenizer, test_set)


def main():
    """CLI entry point: parse arguments, merge the config file, then ``run``.

    Precedence is command line > YAML config file > ``CONFIG_DEFAULTS``. A
    value counts as "unset" when it is None.

    Raises:
        ValueError: if the config file's top level is not a YAML mapping.
    """
    os.environ["TOKENIZERS_PARALLELISM"] = "true"
    parser = build_parser()
    args = vars(parser.parse_args())
    config_path = args["config"]

    if config_path:
        print("Loading configuration file", config_path)
        with open(config_path, "r", encoding="utf-8") as file:
            config = yaml.load(file, yaml_loader)
        # An empty YAML document loads as None; treat it as "no overrides"
        # rather than crashing on None.items().
        if config is None:
            config = {}
        if not isinstance(config, dict):
            raise ValueError(
                f"Configuration file {config_path} must contain a YAML mapping "
                f"at the top level, got {type(config).__name__}."
            )
        # Prefer parameters from command-line arguments
        for k, v in config.items():
            if args.get(k, None) is None:
                args[k] = v

    # Update defaults for unspecified parameters
    for k, v in CONFIG_DEFAULTS.items():
        if args.get(k, None) is None:
            args[k] = v

    run(types.SimpleNamespace(**args))


if __name__ == "__main__":
    main()