# Copyright © 2024 Apple Inc.

import argparse
import math
import re
import types
from pathlib import Path
from typing import Optional

import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
import yaml
from mlx.utils import tree_flatten

from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import build_schedule, linear_to_lora_layers
from .utils import load, save_config

# Register an additional implicit float resolver so that scalars written in
# exponent notation without a decimal point (e.g. ``1e-5``) match the float
# tag. NOTE: ``yaml_loader`` is an alias of ``yaml.SafeLoader`` (not a
# subclass), so the resolver is registered on the shared SafeLoader class.
yaml_loader = yaml.SafeLoader
yaml_loader.add_implicit_resolver(
    "tag:yaml.org,2002:float",
    re.compile(
        """^(?:
     [-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
    |[-+]?(?:[0-9][0-9_]*)(?:[eE][-+]?[0-9]+)
    |\\.[0-9_]+(?:[eE][-+][0-9]+)?
    |[-+]?[0-9][0-9_]*(?::[0-5]?[0-9])+\\.[0-9_]*
    |[-+]?\\.(?:inf|Inf|INF)
    |\\.(?:nan|NaN|NAN))$""",
        re.X,
    ),
    list("-+0123456789."),
)

# Fallback values applied by ``main`` to any option that neither the command
# line nor the YAML configuration file provided.
CONFIG_DEFAULTS = {
    "model": "mlx_model",
    "train": False,
    "data": "data/",
    "seed": 0,
    "lora_layers": 16,
    "batch_size": 4,
    "iters": 1000,
    "val_batches": 25,
    "learning_rate": 1e-5,
    "steps_per_report": 10,
    "steps_per_eval": 200,
    "resume_adapter_file": None,
    "adapter_path": "adapters",
    "save_every": 100,
    "test": False,
    "test_batches": 500,
    "max_seq_length": 2048,
    "lr_schedule": None,
    "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
}


def build_parser():
    """Build the command-line parser for LoRA / QLoRA fine-tuning.

    Most options deliberately have no default (they parse to ``None``) so
    that ``main`` can tell "not given" apart from an explicit value and fall
    back to the configuration file or ``CONFIG_DEFAULTS``. The ``store_true``
    flags default to ``False`` and ``--seed`` defaults to ``0``.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        help="The path to the local model directory or Hugging Face repo.",
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument(
        "--data",
        type=str,
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, help="Minibatch size.")
    parser.add_argument("--iters", type=int, help="Iterations to train for.")
    parser.add_argument(
        "--val-batches",
        type=int,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument("--learning-rate", type=float, help="Adam learning rate.")
    parser.add_argument(
        "--steps-per-report",
        type=int,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        help="Load path to resume training with the given adapters.",
    )
    parser.add_argument(
        "--adapter-path",
        type=str,
        help="Save/load path for the adapters.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument(
        "--max-seq-length",
        type=int,
        help="Maximum sequence length.",
    )
    parser.add_argument(
        "-c",
        "--config",
        default=None,
        help="A YAML configuration file with the training options",
    )
    parser.add_argument(
        "--grad-checkpoint",
        action="store_true",
        help="Use gradient checkpointing to reduce memory use.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser


def print_trainable_parameters(model):
    """Print the share of trainable parameters in ``model``.

    Parameter counts are reported in millions. The total is summed over the
    model's leaf modules; the trainable count comes from
    ``model.trainable_parameters()``.

    Args:
        model: The model to inspect (must expose ``leaf_modules()`` and
            ``trainable_parameters()``).
    """

    def nparams(m):
        if isinstance(m, nn.QuantizedLinear):
            # Scale the stored weight size by the number of ``bits``-wide
            # values per 32 bits (presumably because quantized weights are
            # packed into 32-bit words -- confirm against mlx.nn).
            return m.weight.size * (32 // m.bits)
        return sum(v.size for _, v in tree_flatten(m.parameters()))

    leaf_modules = tree_flatten(
        model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module)
    )
    total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6
    trainable_p = (
        sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    )
    print(
        f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
        f"({trainable_p:.3f}M/{total_p:.3f}M)"
    )


def run(args, training_callback: Optional[TrainingCallback] = None):
    """Fine-tune and/or evaluate LoRA adapters according to ``args``.

    Loads the model, converts linear layers to LoRA layers, optionally
    resumes from existing adapter weights, writes ``adapter_config.json``
    into the adapter directory, trains when ``args.train`` is set, then
    reloads the saved adapters and evaluates on the test set when
    ``args.test`` is set.

    Args:
        args: A namespace carrying every option listed in ``CONFIG_DEFAULTS``
            plus ``grad_checkpoint``.
        training_callback: Optional callback forwarded to ``train``.

    Raises:
        ValueError: If the adapter file does not exist once the (optional)
            training phase has finished.
    """
    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    # Freeze all layers
    model.freeze()
    # Convert linear layers to lora layers and unfreeze in the process
    linear_to_lora_layers(model, args.lora_layers, args.lora_parameters)

    print_trainable_parameters(model)

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args, tokenizer)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    adapter_path = Path(args.adapter_path)
    adapter_path.mkdir(parents=True, exist_ok=True)
    save_config(vars(args), adapter_path / "adapter_config.json")
    adapter_file = adapter_path / "adapters.safetensors"

    if args.train:
        print("Training")
        # init training args
        training_args = TrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            grad_checkpoint=args.grad_checkpoint,
        )

        model.train()
        # A configured schedule takes precedence over the fixed learning rate.
        opt = optim.Adam(
            learning_rate=(
                build_schedule(args.lr_schedule)
                if args.lr_schedule
                else args.learning_rate
            )
        )
        # Train model
        train(
            model=model,
            tokenizer=tokenizer,
            args=training_args,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            training_callback=training_callback,
        )

    # Load the LoRA adapter weights which we assume should exist by this point
    if not adapter_file.is_file():
        raise ValueError(
            f"Adapter file {adapter_file} missing. "
            "Use --train to learn and save the adapters"
        )
    model.load_weights(str(adapter_file), strict=False)

    if args.test:
        print("Testing")
        model.eval()

        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
        )

        test_ppl = math.exp(test_loss)

        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


def _merge_args(cli_args, cli_defaults, config):
    """Combine CLI arguments, config-file values and built-in defaults.

    Precedence is: explicit command-line value, then configuration file,
    then ``CONFIG_DEFAULTS``.

    A command-line value counts as "not given" when it still equals the
    parser default for that option. Comparing against the default (rather
    than testing truthiness) keeps legitimate falsy values such as
    ``--iters 0`` or a config ``learning_rate: 0.0`` from being silently
    replaced. The one case that cannot be distinguished is an explicit
    value identical to a non-``None`` parser default (e.g. ``--seed 0``).

    Args:
        cli_args: Mapping of parsed command-line options.
        cli_defaults: Mapping of option name to its parser default.
        config: Mapping loaded from the YAML configuration file (may be
            empty).

    Returns:
        dict: The merged options; the inputs are not modified.
    """
    merged = dict(cli_args)

    # Prefer parameters from command-line arguments
    for key, value in config.items():
        if merged.get(key) == cli_defaults.get(key):
            merged[key] = value

    # Update defaults for unspecified parameters
    for key, value in CONFIG_DEFAULTS.items():
        if merged.get(key) is None:
            merged[key] = value

    return merged


def main():
    """Command-line entry point: parse options, merge config, and run.

    Raises:
        ValueError: If the configuration file's top level is not a mapping.
    """
    parser = build_parser()
    args = vars(parser.parse_args())
    cli_defaults = {key: parser.get_default(key) for key in args}

    config = {}
    config_path = args["config"]
    if config_path:
        print("Loading configuration file", config_path)
        with open(config_path, "r") as file:
            config = yaml.load(file, yaml_loader)
        if config is None:
            # An empty YAML document loads as None; treat it as "no options".
            config = {}
        elif not isinstance(config, dict):
            raise ValueError(
                f"Configuration file {config_path} must contain a mapping of "
                f"option names to values, got {type(config).__name__}"
            )

    run(types.SimpleNamespace(**_merge_args(args, cli_defaults, config)))


if __name__ == "__main__":
    main()