Merge branch 'main' into adding-dpo-training

This commit is contained in:
Gökdeniz Gülmez
2025-02-10 10:55:39 +01:00
committed by GitHub
17 changed files with 497 additions and 145 deletions

View File

@@ -102,6 +102,14 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=False,
)
parser.add_argument(
"--training-mode",
type=str,
@@ -248,6 +256,7 @@ def train_model(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Train model
if args.training_mode == "dpo":
training_args = DPOTrainingArgs(