From 4b88c33a26fef48dcadb6e4a950bd49fc6721faf Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Tue, 5 Nov 2024 19:10:01 -0500
Subject: [PATCH] Updates CL lora tuner with input masking that uses
 default_loss (and iterate_batches) by default.

---
 llms/mlx_lm/lora.py | 20 ++++++++++++++++++++
 1 file changed, 20 insertions(+)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index c96e75a7..226919c9 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -91,6 +91,15 @@ def build_parser():
         default="lora",
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+
+    parser.add_argument(
+        "--mask-inputs",
+        dest="mask_inputs",
+        action="store_true",
+        help="Whether to mask the inputs when training. Default is False.",
+        default=False,
+    )
+
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -169,6 +178,13 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
+    from .tuner.trainer import (
+        default_loss,
+        input_masked_loss,
+        iterate_batches,
+        iterate_delineated_batches,
+    )
+
     model.freeze()
     if args.fine_tune_type == "full":
         for l in model.layers[-min(args.num_layers, 0) :]:
@@ -225,6 +241,10 @@ def train_model(
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
+        iterate_batches=(
+            iterate_delineated_batches if args.mask_inputs else iterate_batches
+        ),
+        loss=input_masked_loss if args.mask_inputs else default_loss,
     )
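
Note: this diff only wires the new --mask-inputs flag into train_model; the implementations of input_masked_loss and iterate_delineated_batches live in llms/mlx_lm/tuner/trainer.py and are not part of this patch. As a rough illustration of the idea (not the PR's actual code), the sketch below shows what a loss that excludes prompt tokens from the cross-entropy could look like. The function name sketch_input_masked_loss and the input_lengths argument are assumptions made for this example, and the overall shape is only loosely modeled on the existing default_loss.

import mlx.core as mx
import mlx.nn as nn

def sketch_input_masked_loss(model, inputs, targets, input_lengths, lengths):
    # Illustrative sketch only: average cross-entropy over completion tokens,
    # skipping the prompt (input) portion of each sequence.
    logits = model(inputs).astype(mx.float32)

    # Position index of every token in the (batch, seq_len) grid.
    positions = mx.arange(inputs.shape[1])[None, :]

    # Keep a token only if it lies past the prompt and before the padding.
    in_completion = positions >= input_lengths[:, None]
    in_sequence = positions < lengths[:, None]
    mask = mx.logical_and(in_completion, in_sequence)

    ce = nn.losses.cross_entropy(logits, targets, reduction="none") * mask
    ntoks = mask.sum()
    return ce.sum() / ntoks, ntoks

Because default_loss and iterate_batches remain the defaults, existing fine-tuning runs are unchanged unless --mask-inputs is passed.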