Fix optimizer: move optimizer selection (Adam/AdamW via args.optimizer) before the training-mode dispatch so it is initialized for all modes

This commit is contained in:
Goekdeniz-Guelmez 2025-03-08 22:36:20 +01:00
parent 700c3ef5cc
commit 73cc094681

View File

@ -255,11 +255,21 @@ def train_model(
save_config(vars(args), adapter_path / "adapter_config.json") save_config(vars(args), adapter_path / "adapter_config.json")
model.train() model.train()
opt = optim.Adam( # need to correct that part
learning_rate=( # Initialize the selected optimizer
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
) optimizer_name = args.optimizer.lower()
optimizer_config = args.optimizer_config.get(optimizer_name, {})
if optimizer_name == "adam":
opt_class = optim.Adam
elif optimizer_name == "adamw":
opt_class = optim.AdamW
else:
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
opt = opt_class(learning_rate=lr, **optimizer_config)
# Train model based on training mode # Train model based on training mode
if args.training_mode == "orpo": if args.training_mode == "orpo":
@ -298,21 +308,6 @@ def train_model(
max_seq_length=args.max_seq_length, max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint grad_checkpoint=args.grad_checkpoint
) )
# Initialize the selected optimizer
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
optimizer_name = args.optimizer.lower()
optimizer_config = args.optimizer_config.get(optimizer_name, {})
if optimizer_name == "adam":
opt_class = optim.Adam
elif optimizer_name == "adamw":
opt_class = optim.AdamW
else:
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
opt = opt_class(learning_rate=lr, **optimizer_config)
train( train(
model=model, model=model,