mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-11 11:48:39 +08:00
fix optimizer
This commit is contained in:
parent
700c3ef5cc
commit
73cc094681
@@ -255,11 +255,21 @@ def train_model(
|
||||
save_config(vars(args), adapter_path / "adapter_config.json")
|
||||
|
||||
model.train()
|
||||
opt = optim.Adam( # need to correct that part
|
||||
learning_rate=(
|
||||
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
)
|
||||
)
|
||||
|
||||
# Initialize the selected optimizer
|
||||
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
|
||||
optimizer_name = args.optimizer.lower()
|
||||
optimizer_config = args.optimizer_config.get(optimizer_name, {})
|
||||
|
||||
if optimizer_name == "adam":
|
||||
opt_class = optim.Adam
|
||||
elif optimizer_name == "adamw":
|
||||
opt_class = optim.AdamW
|
||||
else:
|
||||
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
|
||||
|
||||
opt = opt_class(learning_rate=lr, **optimizer_config)
|
||||
|
||||
# Train model based on training mode
|
||||
if args.training_mode == "orpo":
|
||||
@@ -298,21 +308,6 @@ def train_model(
|
||||
max_seq_length=args.max_seq_length,
|
||||
grad_checkpoint=args.grad_checkpoint
|
||||
)
|
||||
|
||||
# Initialize the selected optimizer
|
||||
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
|
||||
|
||||
optimizer_name = args.optimizer.lower()
|
||||
optimizer_config = args.optimizer_config.get(optimizer_name, {})
|
||||
|
||||
if optimizer_name == "adam":
|
||||
opt_class = optim.Adam
|
||||
elif optimizer_name == "adamw":
|
||||
opt_class = optim.AdamW
|
||||
else:
|
||||
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
|
||||
|
||||
opt = opt_class(learning_rate=lr, **optimizer_config)
|
||||
|
||||
train(
|
||||
model=model,
|
||||
|
Loading…
Reference in New Issue
Block a user