Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-12-16 02:08:55 +08:00)
Merge branch 'main' into adding-dpo-training
@@ -45,6 +45,11 @@ CONFIG_DEFAULTS = {
     "train": False,
     "fine_tune_type": "lora",
     "training_mode": "normal",
+    "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {},
+        "adamw": {},
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
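
The new defaults make the optimizer selectable from a config file as well: "optimizer" names the choice and "optimizer_config" holds per-optimizer keyword arguments, keyed by optimizer name so only the selected entry is consumed. A minimal sketch of a user config exercising these keys (the weight_decay value is illustrative, not from this commit):

    config = {
        "train": True,
        "optimizer": "adamw",
        # Keyed by optimizer name; only the "adamw" entry is read later,
        # since that is the optimizer selected above.
        "optimizer_config": {
            "adam": {},
            "adamw": {"weight_decay": 0.01},  # forwarded to optim.AdamW
        },
    }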
@@ -102,14 +107,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw"],
+        default=None,
+        help="Optimizer to use for training: adam or adamw",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )
     parser.add_argument(
         "--training-mode",
         type=str,
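
Note that --optimizer defaults to None rather than "adam": like the other flags here, an unset CLI value falls back to CONFIG_DEFAULTS (or a YAML config), so the command line only wins when the flag is passed explicitly. A standalone sketch of that fallback pattern (assumed to mirror the config-merging behavior in lora.py, not copied from it):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--optimizer", type=str, choices=["adam", "adamw"], default=None)
    args = parser.parse_args([])  # no CLI override given

    # None means "not set on the CLI", so the config default applies.
    optimizer = args.optimizer if args.optimizer is not None else "adam"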
@@ -257,11 +267,21 @@ def train_model(
     save_config(vars(args), adapter_path / "adapter_config.json")
 
     model.train()
-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)
 
     # Train model
     if args.training_mode == "dpo":
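
The selection logic above maps the configured name onto an MLX optimizer class and forwards that optimizer's entry from optimizer_config as keyword arguments. A self-contained sketch of the same pattern (the function name and config values are illustrative, not part of the commit):

    import mlx.optimizers as optim

    OPTIMIZER_CLASSES = {"adam": optim.Adam, "adamw": optim.AdamW}

    def build_optimizer(name, lr, optimizer_config):
        opt_class = OPTIMIZER_CLASSES.get(name.lower())
        if opt_class is None:
            raise ValueError(f"Unsupported optimizer: {name}")
        # Forward only the selected optimizer's kwargs, e.g. weight decay.
        return opt_class(learning_rate=lr, **optimizer_config.get(name.lower(), {}))

    opt = build_optimizer("adamw", 1e-5, {"adamw": {"weight_decay": 0.01}})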