mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 16:16:27 +08:00
adding more customized YAML configuration
This commit is contained in:
parent b0a2edbcf3
commit eed093b0ec
@@ -44,6 +44,25 @@ CONFIG_DEFAULTS = {
     "train": False,
     "fine_tune_type": "lora",
     "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "bias_correction": False
+        },
+        "adamw": {
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "weight_decay": 0.01,
+            "bias_correction": False
+        },
+        "muon": {
+            "momentum": 0.95,
+            "weight_decay": 0.01,
+            "nesterov": True,
+            "ns_steps": 5
+        }
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
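
The new optimizer_config block gives each optimizer its own defaults, so a user's YAML file only needs to name the keys it wants to change. Below is a minimal sketch, not from the repo, of how such a config could interact with these defaults, assuming the loader fills missing top-level keys from CONFIG_DEFAULTS (a shallow merge); the file contents and values are illustrative.

# Minimal sketch, not from the repo: a user YAML overriding the new
# per-optimizer defaults. Assumes a shallow merge over CONFIG_DEFAULTS;
# the .get(...) fallbacks in train_model (next hunk) backstop any keys
# the YAML omits.
import yaml  # PyYAML

CONFIG_DEFAULTS = {  # trimmed to the keys relevant here
    "optimizer": "adam",
    "optimizer_config": {
        "adam": {"betas": [0.9, 0.999], "eps": 1e-8, "bias_correction": False},
        "adamw": {"betas": [0.9, 0.999], "eps": 1e-8,
                  "weight_decay": 0.01, "bias_correction": False},
        "muon": {"momentum": 0.95, "weight_decay": 0.01,
                 "nesterov": True, "ns_steps": 5},
    },
}

user_yaml = """
optimizer: adamw
optimizer_config:
  adamw:
    weight_decay: 0.05
    bias_correction: true
"""

config = {**CONFIG_DEFAULTS, **yaml.safe_load(user_yaml)}
name = config["optimizer"]                       # "adamw"
opts = config["optimizer_config"].get(name, {})  # only the overridden keys
print(name, opts)  # adamw {'weight_decay': 0.05, 'bias_correction': True}

Because the merge is shallow, a YAML file that sets optimizer_config replaces the whole dictionary; the per-key .get(...) defaults in train_model below are what keep the omitted betas and eps sensible.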
@@ -239,14 +258,34 @@ def train_model(
     # Initialize the selected optimizer
     lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
 
-    if args.optimizer.lower() == "adam":
-        opt = optim.Adam(learning_rate=lr)
-    elif args.optimizer.lower() == "adamw":
-        opt = optim.AdamW(learning_rate=lr)
-    elif args.optimizer.lower() == "muon":
-        opt = optim.Muon(learning_rate=lr)
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt = optim.Adam(
+            learning_rate=lr,
+            betas=optimizer_config.get("betas", [0.9, 0.999]),
+            eps=optimizer_config.get("eps", 1e-8),
+            bias_correction=optimizer_config.get("bias_correction", False)
+        )
+    elif optimizer_name == "adamw":
+        opt = optim.AdamW(
+            learning_rate=lr,
+            betas=optimizer_config.get("betas", [0.9, 0.999]),
+            eps=optimizer_config.get("eps", 1e-8),
+            weight_decay=optimizer_config.get("weight_decay", 0.01),
+            bias_correction=optimizer_config.get("bias_correction", False)
+        )
+    elif optimizer_name == "muon":
+        opt = optim.Muon(
+            learning_rate=lr,
+            momentum=optimizer_config.get("momentum", 0.95),
+            weight_decay=optimizer_config.get("weight_decay", 0.01),
+            nesterov=optimizer_config.get("nesterov", True),
+            ns_steps=optimizer_config.get("ns_steps", 5)
+        )
     else:
-        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
 
     # Train model
     train(
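
The new branch bodies differ only in which keyword arguments they forward, so the chain could also be written as a dispatch table. A hypothetical, behavior-preserving sketch; names like OPT_CLASSES and build_optimizer are illustrative and not part of the commit:

# Hypothetical refactor of the if/elif chain above; not part of the commit.
# Keys missing from optimizer_config fall back to the constructors' own
# defaults here, rather than the literals spelled out in the diff (the two
# are assumed to coincide in current MLX).
import mlx.optimizers as optim

OPT_CLASSES = {"adam": optim.Adam, "adamw": optim.AdamW, "muon": optim.Muon}

def build_optimizer(args, lr):
    name = args.optimizer.lower()
    if name not in OPT_CLASSES:
        raise ValueError(f"Unsupported optimizer: {name}")
    # Forward the matching optimizer_config section as keyword arguments.
    return OPT_CLASSES[name](learning_rate=lr, **args.optimizer_config.get(name, {}))

One behavioral nit visible in the diff itself: the new raise uses optimizer_name, so the error message shows the lowercased name, whereas the removed line reported args.optimizer verbatim.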