diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 43f508c3..153bc49d 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -94,6 +94,15 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+
+    parser.add_argument(
+        "--mask-inputs",
+        dest="mask_inputs",
+        action="store_true",
+        help="Whether to mask the inputs when training. Default is False.",
+        default=False,
+    )
+
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -172,6 +181,13 @@ def train_model(
     valid_set,
     training_callback: TrainingCallback = None,
 ):
+    from .tuner.trainer import (
+        default_loss,
+        input_masked_loss,
+        iterate_batches,
+        iterate_delineated_batches,
+    )
+
     model.freeze()
     if args.fine_tune_type == "full":
         for l in model.layers[-min(args.num_layers, 0) :]:
@@ -228,6 +244,10 @@
         train_dataset=train_set,
         val_dataset=valid_set,
         training_callback=training_callback,
+        iterate_batches=(
+            iterate_delineated_batches if args.mask_inputs else iterate_batches
+        ),
+        loss=input_masked_loss if args.mask_inputs else default_loss,
     )
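
Note: this diff only wires the new --mask-inputs flag into train_model; input_masked_loss and iterate_delineated_batches are imported from .tuner.trainer and are not shown here. Below is a minimal sketch of what an input-masked loss could look like, modeled on mlx_lm's default_loss. The extra input_masks argument, its shape, and the overall signature are assumptions for illustration, not the PR's actual implementation.

# Hypothetical sketch only: assumes the batch iterator also yields a boolean
# input_masks array of shape (batch, seq) marking prompt tokens, which this
# diff does not show.
import mlx.core as mx
import mlx.nn as nn


def input_masked_loss(model, inputs, targets, input_masks, lengths):
    # inputs/targets: (batch, seq) token ids; lengths: (batch,) unpadded lengths.
    logits = model(inputs).astype(mx.float32)

    # Score only positions that are inside the true sequence and not part of
    # the prompt, so training signal comes from completion tokens alone.
    length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None]
    loss_mask = mx.logical_and(length_mask, mx.logical_not(input_masks))

    # Average per-token cross-entropy over the unmasked (completion) tokens.
    ce = nn.losses.cross_entropy(logits, targets) * loss_mask
    ntoks = loss_mask.sum()
    return ce.sum() / ntoks, ntoks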