adding more customized YAML configuration

Goekdeniz-Guelmez 2025-03-01 15:00:29 +01:00
parent b0a2edbcf3
commit eed093b0ec


@@ -44,6 +44,25 @@ CONFIG_DEFAULTS = {
     "train": False,
     "fine_tune_type": "lora",
     "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "bias_correction": False
+        },
+        "adamw": {
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "weight_decay": 0.01,
+            "bias_correction": False
+        },
+        "muon": {
+            "momentum": 0.95,
+            "weight_decay": 0.01,
+            "nesterov": True,
+            "ns_steps": 5
+        }
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -239,14 +258,34 @@ def train_model(
     # Initialize the selected optimizer
     lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
 
-    if args.optimizer.lower() == "adam":
-        opt = optim.Adam(learning_rate=lr)
-    elif args.optimizer.lower() == "adamw":
-        opt = optim.AdamW(learning_rate=lr)
-    elif args.optimizer.lower() == "muon":
-        opt = optim.Muon(learning_rate=lr)
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt = optim.Adam(
+            learning_rate=lr,
+            betas=optimizer_config.get("betas", [0.9, 0.999]),
+            eps=optimizer_config.get("eps", 1e-8),
+            bias_correction=optimizer_config.get("bias_correction", False)
+        )
+    elif optimizer_name == "adamw":
+        opt = optim.AdamW(
+            learning_rate=lr,
+            betas=optimizer_config.get("betas", [0.9, 0.999]),
+            eps=optimizer_config.get("eps", 1e-8),
+            weight_decay=optimizer_config.get("weight_decay", 0.01),
+            bias_correction=optimizer_config.get("bias_correction", False)
+        )
+    elif optimizer_name == "muon":
+        opt = optim.Muon(
+            learning_rate=lr,
+            momentum=optimizer_config.get("momentum", 0.95),
+            weight_decay=optimizer_config.get("weight_decay", 0.01),
+            nesterov=optimizer_config.get("nesterov", True),
+            ns_steps=optimizer_config.get("ns_steps", 5)
+        )
     else:
-        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
 
     # Train model
     train(
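
The dispatch above tolerates partial configs: args.optimizer_config.get(optimizer_name, {}) returns an empty dict for optimizers with no user entry, and every keyword argument repeats its default inside the .get() call. A minimal sketch of that pattern, outside the diff:

# A partial per-optimizer config still works because each .get()
# carries its own fallback default.
optimizer_config = {"muon": {"momentum": 0.9}}  # user set only one key

cfg = optimizer_config.get("muon", {})
momentum = cfg.get("momentum", 0.95)  # 0.9 -> from the user config
ns_steps = cfg.get("ns_steps", 5)     # 5   -> fallback default
print(momentum, ns_steps)             # 0.9 5

# An optimizer with no user entry yields an empty dict, so all
# defaults apply.
print(optimizer_config.get("adamw", {}))  # {}

One consequence worth noting: the fallback values are duplicated between CONFIG_DEFAULTS and the .get() calls, so the two lists have to be kept in sync by hand.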