Changed the optimizer if/elif switch to only set opt_class; the optimizer is then constructed once from optimizer_config

This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 09:40:36 +01:00
parent 60df71bcbc
commit 64ed426518

View File

@@ -45,23 +45,14 @@ CONFIG_DEFAULTS = {
"fine_tune_type": "lora", "fine_tune_type": "lora",
"optimizer": "adam", "optimizer": "adam",
"optimizer_config": { "optimizer_config": {
"adam": { "adam": {"bias_correction": False},
"betas": [0.9, 0.999], "adamw": {"weight_decay": 0.01, "bias_correction": False},
"eps": 1e-8,
"bias_correction": False
},
"adamw": {
"betas": [0.9, 0.999],
"eps": 1e-8,
"weight_decay": 0.01,
"bias_correction": False
},
"muon": { "muon": {
"momentum": 0.95, "momentum": 0.95,
"weight_decay": 0.01, "weight_decay": 0.01,
"nesterov": True, "nesterov": True,
"ns_steps": 5 "ns_steps": 5,
} },
}, },
"data": "data/", "data": "data/",
"seed": 0, "seed": 0,
@@ -116,7 +107,7 @@ def build_parser():
help="Type of fine-tuning to perform: lora, dora, or full.", help="Type of fine-tuning to perform: lora, dora, or full.",
) )
parser.add_argument( parser.add_argument(
'--optimizer', "--optimizer",
type=str, type=str,
choices=["adam", "adamw", "muon"], choices=["adam", "adamw", "muon"],
default="adam", default="adam",
@@ -262,31 +253,16 @@ def train_model(
optimizer_config = args.optimizer_config.get(optimizer_name, {}) optimizer_config = args.optimizer_config.get(optimizer_name, {})
if optimizer_name == "adam": if optimizer_name == "adam":
opt = optim.Adam( opt_class = optim.Adam
learning_rate=lr,
betas=optimizer_config.get("betas", [0.9, 0.999]),
eps=optimizer_config.get("eps", 1e-8),
bias_correction=optimizer_config.get("bias_correction", False)
)
elif optimizer_name == "adamw": elif optimizer_name == "adamw":
opt = optim.AdamW( opt_class = optim.AdamW
learning_rate=lr,
betas=optimizer_config.get("betas", [0.9, 0.999]),
eps=optimizer_config.get("eps", 1e-8),
weight_decay=optimizer_config.get("weight_decay", 0.01),
bias_correction=optimizer_config.get("bias_correction", False)
)
elif optimizer_name == "muon": elif optimizer_name == "muon":
opt = optim.Muon( opt_class = optim.Muon
learning_rate=lr,
momentum=optimizer_config.get("momentum", 0.95),
weight_decay=optimizer_config.get("weight_decay", 0.01),
nesterov=optimizer_config.get("nesterov", True),
ns_steps=optimizer_config.get("ns_steps", 5)
)
else: else:
raise ValueError(f"Unsupported optimizer: {optimizer_name}") raise ValueError(f"Unsupported optimizer: {optimizer_name}")
opt = opt_class(learning_rate=lr, **optimizer_config)
# Train model # Train model
train( train(
model=model, model=model,