mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
Merge branch 'main' into adding-GRPO-training
This commit is contained in:
@@ -105,6 +105,14 @@ def build_parser():
|
||||
choices=["lora", "dora", "full"],
|
||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--mask-prompt",
|
||||
action="store_true",
|
||||
help="Mask the prompt in the loss when training",
|
||||
default=False,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--training-mode",
|
||||
type=str,
|
||||
@@ -274,6 +282,7 @@ def train_model(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
)
|
||||
|
||||
# Train model
|
||||
if args.training_mode == "grpo":
|
||||
training_args = GRPOTrainingArgs(
|
||||
|
||||
Reference in New Issue
Block a user