Merge branch 'adding-dpo-training' of https://github.com/Goekdeniz-Guelmez/mlx-examples into adding-dpo-training

Goekdeniz-Guelmez committed on 2025-02-10 10:56:57 +01:00
19 changed files with 499 additions and 147 deletions


@@ -102,6 +102,14 @@ def build_parser():
        choices=["lora", "dora", "full"],
        help="Type of fine-tuning to perform: lora, dora, or full.",
    )
    parser.add_argument(
        "--mask-prompt",
        action="store_true",
        help="Mask the prompt in the loss when training",
        default=False,
    )
    parser.add_argument(
        "--training-mode",
        type=str,
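
Note: the new --mask-prompt flag excludes prompt tokens from the training loss so that only completion tokens contribute to the gradient. As an illustration of the general idea only (not this trainer's actual implementation; the function name and array shapes below are hypothetical), a minimal NumPy sketch:

import numpy as np

def prompt_masked_loss(token_losses, lengths, prompt_lengths, mask_prompt=True):
    # token_losses: (batch, seq) per-token cross-entropy values
    # lengths: (batch,) valid tokens per example; prompt_lengths: (batch,) prompt tokens
    positions = np.arange(token_losses.shape[1])[None, :]
    mask = positions < lengths[:, None]               # drop padding positions
    if mask_prompt:
        mask &= positions >= prompt_lengths[:, None]  # drop prompt positions
    return (token_losses * mask).sum() / max(mask.sum(), 1)
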
@@ -248,6 +256,7 @@ def train_model(
            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
        )
    )
    # Train model
    if args.training_mode == "dpo":
        training_args = DPOTrainingArgs(
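
For context on the branch above: when --training-mode is "dpo", train_model builds a DPOTrainingArgs instance instead of the standard training arguments (its exact fields are not shown in this hunk). The preference objective itself is the standard sigmoid DPO loss (Rafailov et al., 2023); a self-contained NumPy sketch of that formula, independent of this repo's code, is:

import numpy as np

def dpo_loss(policy_chosen_logps, policy_rejected_logps,
             ref_chosen_logps, ref_rejected_logps, beta=0.1):
    # Inputs are sequence-level log-probabilities under the trained policy
    # and the frozen reference model for chosen/rejected completions.
    chosen_margin = policy_chosen_logps - ref_chosen_logps
    rejected_margin = policy_rejected_logps - ref_rejected_logps
    logits = beta * (chosen_margin - rejected_margin)
    # -log(sigmoid(x)) written as log(1 + exp(-x)) for numerical stability
    return np.log1p(np.exp(-logits)).mean()
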