import argparse
import json
import math
from pathlib import Path

import mlx.optimizers as optim
import numpy as np
from mlx.utils import tree_flatten

from .models import llama, mixtral, phi2
from .tuner.lora import LoRALinear
from .tuner.trainer import TrainingArgs, evaluate, train
from .utils import generate, load

SUPPORTED_MODELS = [llama.Model, mixtral.Model, phi2.Model]


def build_parser():
    parser = argparse.ArgumentParser(description="LoRA or QLoRA finetuning.")
    parser.add_argument(
        "--model",
        default="mlx_model",
        help="The path to the local model directory or Hugging Face repo.",
    )
    # Generation args
    parser.add_argument(
        "--max-tokens",
        "-m",
        type=int,
        default=100,
        help="The maximum number of tokens to generate",
    )
    parser.add_argument(
        "--temp", type=float, default=0.8, help="The sampling temperature"
    )
    parser.add_argument(
        "--prompt",
        "-p",
        type=str,
        help="The prompt for generation",
        default=None,
    )

    # Training args
    parser.add_argument(
        "--train",
        action="store_true",
        help="Do training",
    )
    parser.add_argument(
        "--data",
        type=str,
        default="data/",
        help="Directory with {train, valid, test}.jsonl files",
    )
    parser.add_argument(
        "--lora-layers",
        type=int,
        default=16,
        help="Number of layers to fine-tune",
    )
    parser.add_argument("--batch-size", type=int, default=4, help="Minibatch size.")
    parser.add_argument(
        "--iters", type=int, default=1000, help="Iterations to train for."
    )
    parser.add_argument(
        "--val-batches",
        type=int,
        default=25,
        help="Number of validation batches, -1 uses the entire validation set.",
    )
    parser.add_argument(
        "--learning-rate", type=float, default=1e-5, help="Adam learning rate."
    )
    parser.add_argument(
        "--steps-per-report",
        type=int,
        default=10,
        help="Number of training steps between loss reporting.",
    )
    parser.add_argument(
        "--steps-per-eval",
        type=int,
        default=200,
        help="Number of training steps between validations.",
    )
    parser.add_argument(
        "--resume-adapter-file",
        type=str,
        default=None,
        help="Load path to resume training with the given adapter weights.",
    )
    parser.add_argument(
        "--adapter-file",
        type=str,
        default="adapters.npz",
        help="Save/load path for the trained adapter weights.",
    )
    parser.add_argument(
        "--save-every",
        type=int,
        default=100,
        help="Save the model every N iterations.",
    )
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
    )
    parser.add_argument(
        "--test-batches",
        type=int,
        default=500,
        help="Number of test set batches, -1 uses the entire test set.",
    )
    parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
    return parser


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path, key: str = "text"):
        if not path.exists():
            self._data = None
        else:
            with open(path, "r") as fid:
                self._data = [json.loads(l) for l in fid]
        self._key = key

    def __getitem__(self, idx: int):
        return self._data[idx][self._key]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)


def load_dataset(args):
    names = ("train", "valid", "test")
    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test


if __name__ == "__main__":
    parser = build_parser()
    args = parser.parse_args()

    np.random.seed(args.seed)

    print("Loading pretrained model")
    model, tokenizer = load(args.model)

    if model.__class__ not in SUPPORTED_MODELS:
        raise ValueError(
            f"Model {model.__class__} not supported. "
            f"Supported models: {SUPPORTED_MODELS}"
        )

    # Freeze all layers other than LoRA linears
    model.freeze()
    for l in model.model.layers[len(model.model.layers) - args.lora_layers :]:
        l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj)
        l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj)
        if hasattr(l, "block_sparse_moe"):
            l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate)

    p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
    print(f"Total parameters {p:.3f}M")
    p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6
    print(f"Trainable parameters {p:.3f}M")

    print("Loading datasets")
    train_set, valid_set, test_set = load_dataset(args)

    # Resume training the given adapters.
    if args.resume_adapter_file is not None:
        print(f"Loading pretrained adapters from {args.resume_adapter_file}")
        model.load_weights(args.resume_adapter_file, strict=False)

    # Init training args
    trainingArgs = TrainingArgs(
        batch_size=args.batch_size,
        iters=args.iters,
        val_batches=args.val_batches,
        steps_per_report=args.steps_per_report,
        steps_per_eval=args.steps_per_eval,
        steps_per_save=args.save_every,
        adapter_file=args.adapter_file,
    )

    if args.train:
        print("Training")
        model.train()
        opt = optim.Adam(learning_rate=args.learning_rate)

        # Train model
        train(
            model=model,
            tokenizer=tokenizer,
            args=trainingArgs,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
        )

    # Load the LoRA adapter weights which we assume should exist by this point
    if not Path(args.adapter_file).is_file():
        raise ValueError(
            f"Adapter file {args.adapter_file} missing. "
            "Use --train to learn and save the adapters.npz."
        )
    model.load_weights(args.adapter_file, strict=False)

    if args.test:
        print("Testing")
        model.eval()
        test_loss = evaluate(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
        )
        test_ppl = math.exp(test_loss)
        print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

    if args.prompt is not None:
        print("Generating")
        model.eval()
        generate(
            model=model,
            tokenizer=tokenizer,
            temp=args.temp,
            max_tokens=args.max_tokens,
            prompt=args.prompt,
            verbose=True,
        )
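
# ---------------------------------------------------------------------------
# Usage notes: a minimal sketch, not part of the script's behavior.
#
# The Dataset wrapper above reads the "text" field from each line of the
# {train, valid, test}.jsonl files inside the --data directory, so the
# expected layout is:
#
#   data/train.jsonl
#   data/valid.jsonl
#   data/test.jsonl
#
# with each line being a JSON object such as {"text": "..."}.
#
# Because of the relative imports at the top, this file is assumed to run as a
# module inside its package; <package> below is a placeholder for that package
# name, and the flags shown are the ones defined in build_parser():
#
#   python -m <package>.lora --model mlx_model --train --data data/ --iters 1000
#   python -m <package>.lora --model mlx_model --test
#   python -m <package>.lora --model mlx_model --prompt "..." --max-tokens 100
# ---------------------------------------------------------------------------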