mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
cleaning up some namings
This commit is contained in:
@@ -67,8 +67,7 @@ CONFIG_DEFAULTS = {
|
||||
"beta": 0.1,
|
||||
"dpo_loss_type": "sigmoid",
|
||||
"delta": 50.0,
|
||||
"reference_model_path": None,
|
||||
"train_bias_only": False,
|
||||
"reference_model_path": None
|
||||
}
|
||||
|
||||
|
||||
@@ -173,12 +172,35 @@ def build_parser():
|
||||
help="Use gradient checkpointing to reduce memory use.",
|
||||
default=None,
|
||||
)
|
||||
parser.add_argument("--beta", type=float)
|
||||
parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
|
||||
parser.add_argument("--delta", type=float)
|
||||
parser.add_argument("--reference-model-path", type=str)
|
||||
parser.add_argument("--train-bias-only", action="store_true")
|
||||
parser.add_argument("--seed", type=int, help="The PRNG seed")
|
||||
|
||||
# DPO args
|
||||
parser.add_argument(
|
||||
"--beta",
|
||||
type=float,
|
||||
help="Temperature parameter for DPO training.",
|
||||
default=0.1
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dpo-loss-type",
|
||||
type=str,
|
||||
help="DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'.",
|
||||
choices=["sigmoid", "hinge", "ipo", "dpop"],
|
||||
default="sigmoid"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--delta",
|
||||
type=float,
|
||||
help="Delta parameter for DPOP loss type.",
|
||||
default=50.0
|
||||
)
|
||||
parser.add_argument(
|
||||
"--reference-model-path",
|
||||
type=str,
|
||||
help="Path to reference model weights. If None, uses the same model.",
|
||||
default=None
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user