mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 18:26:37 +08:00
Changed the optimizer switch to only select opt_class; the optimizer is now constructed once from optimizer_config
This commit is contained in:
parent
60df71bcbc
commit
64ed426518
@@ -45,23 +45,14 @@ CONFIG_DEFAULTS = {
|
|||||||
"fine_tune_type": "lora",
|
"fine_tune_type": "lora",
|
||||||
"optimizer": "adam",
|
"optimizer": "adam",
|
||||||
"optimizer_config": {
|
"optimizer_config": {
|
||||||
"adam": {
|
"adam": {"bias_correction": False},
|
||||||
"betas": [0.9, 0.999],
|
"adamw": {"weight_decay": 0.01, "bias_correction": False},
|
||||||
"eps": 1e-8,
|
|
||||||
"bias_correction": False
|
|
||||||
},
|
|
||||||
"adamw": {
|
|
||||||
"betas": [0.9, 0.999],
|
|
||||||
"eps": 1e-8,
|
|
||||||
"weight_decay": 0.01,
|
|
||||||
"bias_correction": False
|
|
||||||
},
|
|
||||||
"muon": {
|
"muon": {
|
||||||
"momentum": 0.95,
|
"momentum": 0.95,
|
||||||
"weight_decay": 0.01,
|
"weight_decay": 0.01,
|
||||||
"nesterov": True,
|
"nesterov": True,
|
||||||
"ns_steps": 5
|
"ns_steps": 5,
|
||||||
}
|
},
|
||||||
},
|
},
|
||||||
"data": "data/",
|
"data": "data/",
|
||||||
"seed": 0,
|
"seed": 0,
|
||||||
@@ -116,7 +107,7 @@ def build_parser():
|
|||||||
help="Type of fine-tuning to perform: lora, dora, or full.",
|
help="Type of fine-tuning to perform: lora, dora, or full.",
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
'--optimizer',
|
"--optimizer",
|
||||||
type=str,
|
type=str,
|
||||||
choices=["adam", "adamw", "muon"],
|
choices=["adam", "adamw", "muon"],
|
||||||
default="adam",
|
default="adam",
|
||||||
@@ -262,31 +253,16 @@ def train_model(
|
|||||||
optimizer_config = args.optimizer_config.get(optimizer_name, {})
|
optimizer_config = args.optimizer_config.get(optimizer_name, {})
|
||||||
|
|
||||||
if optimizer_name == "adam":
|
if optimizer_name == "adam":
|
||||||
opt = optim.Adam(
|
opt_class = optim.Adam
|
||||||
learning_rate=lr,
|
|
||||||
betas=optimizer_config.get("betas", [0.9, 0.999]),
|
|
||||||
eps=optimizer_config.get("eps", 1e-8),
|
|
||||||
bias_correction=optimizer_config.get("bias_correction", False)
|
|
||||||
)
|
|
||||||
elif optimizer_name == "adamw":
|
elif optimizer_name == "adamw":
|
||||||
opt = optim.AdamW(
|
opt_class = optim.AdamW
|
||||||
learning_rate=lr,
|
|
||||||
betas=optimizer_config.get("betas", [0.9, 0.999]),
|
|
||||||
eps=optimizer_config.get("eps", 1e-8),
|
|
||||||
weight_decay=optimizer_config.get("weight_decay", 0.01),
|
|
||||||
bias_correction=optimizer_config.get("bias_correction", False)
|
|
||||||
)
|
|
||||||
elif optimizer_name == "muon":
|
elif optimizer_name == "muon":
|
||||||
opt = optim.Muon(
|
opt_class = optim.Muon
|
||||||
learning_rate=lr,
|
|
||||||
momentum=optimizer_config.get("momentum", 0.95),
|
|
||||||
weight_decay=optimizer_config.get("weight_decay", 0.01),
|
|
||||||
nesterov=optimizer_config.get("nesterov", True),
|
|
||||||
ns_steps=optimizer_config.get("ns_steps", 5)
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
|
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
|
||||||
|
|
||||||
|
opt = opt_class(learning_rate=lr, **optimizer_config)
|
||||||
|
|
||||||
# Train model
|
# Train model
|
||||||
train(
|
train(
|
||||||
model=model,
|
model=model,
|
||||||
|
Loading…
Reference in New Issue
Block a user