fix optimizer

Goekdeniz-Guelmez 2025-03-08 22:36:20 +01:00
parent 700c3ef5cc
commit 73cc094681


@@ -255,11 +255,21 @@ def train_model(
     save_config(vars(args), adapter_path / "adapter_config.json")
 
     model.train()
-    opt = optim.Adam( # need to correct that part
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)
 
     # Train model based on training mode
     if args.training_mode == "orpo":
@@ -299,21 +309,6 @@ def train_model(
             grad_checkpoint=args.grad_checkpoint
         )
 
-        # Initialize the selected optimizer
-        lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-
-        optimizer_name = args.optimizer.lower()
-        optimizer_config = args.optimizer_config.get(optimizer_name, {})
-
-        if optimizer_name == "adam":
-            opt_class = optim.Adam
-        elif optimizer_name == "adamw":
-            opt_class = optim.AdamW
-        else:
-            raise ValueError(f"Unsupported optimizer: {optimizer_name}")
-
-        opt = opt_class(learning_rate=lr, **optimizer_config)
-
         train(
             model=model,
             tokenizer=tokenizer,
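
For context, a minimal standalone sketch of the selection logic this commit moves to the top of train_model, written against mlx.optimizers. The optimizer name, learning rate, and per-optimizer config mirror args.optimizer, the built schedule, and args.optimizer_config from the diff; the build_optimizer helper and the example values at the end are illustrative assumptions, not part of the commit.

# Sketch only: same optimizer-selection behavior as the block added above,
# factored into a helper so it can be exercised outside train_model.
import mlx.optimizers as optim

def build_optimizer(name: str, learning_rate, config: dict):
    # Map the lowercased name to an mlx.optimizers class, as in the diff.
    optimizers = {"adam": optim.Adam, "adamw": optim.AdamW}
    name = name.lower()
    if name not in optimizers:
        raise ValueError(f"Unsupported optimizer: {name}")
    # Extra keyword arguments (e.g. weight_decay for AdamW) come from the
    # per-optimizer config dict, matching args.optimizer_config in the diff.
    return optimizers[name](learning_rate=learning_rate, **config)

# Illustrative usage (values are assumptions, not taken from the commit):
opt = build_optimizer("adamw", 1e-5, {"weight_decay": 0.01})

Routing through a name-to-class mapping keeps adding another optimizer to a one-line change, and keeping optimizer-specific keyword arguments in the config dict avoids hard-coding values such as weight decay in the training loop.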