Updates the command-line LoRA tuner with input masking, using default_loss (and iterate_batches) by default.

Chime Ogbuji 2024-11-05 19:10:01 -05:00 committed by Awni Hannun
parent 84fc1bde48
commit 27cd361d76


@@ -94,6 +94,15 @@ def build_parser():
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--mask-inputs",
        dest="mask_inputs",
        action="store_true",
        help="Whether to mask the inputs when training. Default is False.",
        default=False,
    )
    parser.add_argument(
        "--num-layers",
        type=int,
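As a self-contained sketch of what the new flag does at the argparse level (this builds a throwaway parser mirroring the added argument, rather than the project's build_parser):

import argparse

parser = argparse.ArgumentParser()
# Mirror of the argument added above: a boolean switch that stays False
# unless --mask-inputs is passed on the command line.
parser.add_argument(
    "--mask-inputs",
    dest="mask_inputs",
    action="store_true",
    help="Whether to mask the inputs when training. Default is False.",
    default=False,
)

print(parser.parse_args([]).mask_inputs)                 # False
print(parser.parse_args(["--mask-inputs"]).mask_inputs)  # True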
@@ -172,6 +181,13 @@ def train_model(
    valid_set,
    training_callback: TrainingCallback = None,
):
    from .tuner.trainer import (
        default_loss,
        input_masked_loss,
        iterate_batches,
        iterate_delineated_batches,
    )
    model.freeze()
    if args.fine_tune_type == "full":
        for l in model.layers[-max(args.num_layers, 0) :]:
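To make the default_loss / input_masked_loss distinction concrete, here is an illustrative masked cross-entropy in MLX. It is a sketch only, and the real input_masked_loss may use a different signature; the idea is to zero out prompt (input) positions and padding so that only completion tokens contribute to the average loss.

import mlx.core as mx
import mlx.nn as nn

def masked_ce_loss(model, inputs, targets, lengths, input_lengths):
    # Token-level cross-entropy over the whole sequence.
    logits = model(inputs).astype(mx.float32)
    ce = nn.losses.cross_entropy(logits, targets)  # reduction defaults to "none"
    # Keep only positions at or past the prompt length and before padding.
    steps = mx.arange(targets.shape[1])[None, :]
    mask = (steps >= input_lengths[:, None]) & (steps < lengths[:, None])
    mask = mask.astype(ce.dtype)
    ntoks = mask.sum()
    return (ce * mask).sum() / ntoks, ntoks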
@@ -228,6 +244,10 @@ def train_model(
        train_dataset=train_set,
        val_dataset=valid_set,
        training_callback=training_callback,
        iterate_batches=(
            iterate_delineated_batches if args.mask_inputs else iterate_batches
        ),
        loss=input_masked_loss if args.mask_inputs else default_loss,
    )
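The batch iterator has to cooperate with that loss: unlike the default iterate_batches, a prompt-delineated iterator also needs to yield per-example prompt lengths. A rough sketch follows, assuming each dataset item is a (prompt_tokens, completion_tokens) pair; this is an illustration, not the actual iterate_delineated_batches.

import mlx.core as mx

def iterate_prompt_completion_batches(dataset, batch_size):
    for i in range(0, len(dataset), batch_size):
        chunk = dataset[i : i + batch_size]
        lengths = [len(p) + len(c) for p, c in chunk]
        input_lengths = [len(p) for p, _ in chunk]
        max_len = max(lengths)
        # Right-pad every example to the longest sequence in the batch.
        tokens = [p + c + [0] * (max_len - n) for (p, c), n in zip(chunk, lengths)]
        # Yield prompt lengths alongside the usual tokens and lengths so the
        # loss can mask out the prompt positions.
        yield mx.array(tokens), mx.array(lengths), mx.array(input_lengths)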