Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 01:17:28 +08:00
Adding multiple optimizers to mlx lm (#1315)
* initial commit
* adding more customized YAML configuration
* update YAML example file
* Changed the switch to set opt_class
* removing muon
* using default arguments
* update
This commit is contained in: parent 56d2db23e1, commit e150621095
@@ -7,6 +7,15 @@ train: true
 # The fine-tuning method: "lora", "dora", or "full".
 fine_tune_type: lora
 
+# The Optimizer with its possible inputs
+optimizer: adamw
+# optimizer_config:
+#   adamw:
+#     betas: [0.9, 0.98]
+#     eps: 1e-6
+#     weight_decay: 0.05
+#     bias_correction: true
+
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
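The commented-out block documents the intended shape of optimizer_config: a mapping from optimizer name to the keyword arguments passed to that optimizer's constructor. Below is a minimal sketch, assuming PyYAML and with the snippet and variable names invented for illustration, of how such a section parses once uncommented. Note that plain yaml.safe_load reads a bare 1e-6 as a string rather than a float, so scientific-notation values either need a decimal point (1.0e-6) or a custom resolver.

# Illustrative only: parse an optimizer section shaped like the example above
# and pull out the kwargs for the selected optimizer.
import yaml

config_text = """
optimizer: adamw
optimizer_config:
  adamw:
    betas: [0.9, 0.98]
    weight_decay: 0.05
"""

config = yaml.safe_load(config_text)
name = config["optimizer"]                         # "adamw"
kwargs = config["optimizer_config"].get(name, {})  # {"betas": [0.9, 0.98], "weight_decay": 0.05}
print(name, kwargs)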
@@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {},
+        "adamw": {},
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
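With these defaults the training script falls back to plain Adam with empty per-optimizer settings when nothing is specified. The sketch below shows the merge behaviour such defaults usually imply; the apply_defaults helper is hypothetical and only illustrates the idea, it is not the merging code from the repo.

# Hypothetical helper: user-supplied values (CLI or YAML) win over defaults,
# while unset keys fall back to CONFIG_DEFAULTS.
CONFIG_DEFAULTS = {
    "optimizer": "adam",
    "optimizer_config": {"adam": {}, "adamw": {}},
}

def apply_defaults(user_config: dict, defaults: dict) -> dict:
    merged = dict(defaults)
    merged.update({k: v for k, v in user_config.items() if v is not None})
    return merged

print(apply_defaults({}, CONFIG_DEFAULTS)["optimizer"])                      # adam
print(apply_defaults({"optimizer": "adamw"}, CONFIG_DEFAULTS)["optimizer"])  # adamw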
@@ -95,14 +100,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
-
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw"],
+        default=None,
+        help="Optimizer to use for training: adam or adamw",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )
-
     parser.add_argument(
         "--num-layers",
         type=int,
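The new flag defaults to None so an optimizer chosen in the YAML config is not silently overridden; only an explicit --optimizer on the command line takes precedence. A self-contained sketch of that behaviour, using a stand-alone parser rather than the repo's full build_parser:

# Stand-alone illustration of the --optimizer flag: default=None means
# "not given on the CLI", letting a later config-merge step keep the YAML value.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    choices=["adam", "adamw"],
    default=None,
    help="Optimizer to use for training: adam or adamw",
)

print(parser.parse_args([]).optimizer)                        # None -> YAML/default wins
print(parser.parse_args(["--optimizer", "adamw"]).optimizer)  # adamw

From the shell the option would then be passed as something like mlx_lm.lora --train --optimizer adamw (invocation shown for illustration).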
@@ -229,11 +239,21 @@ def train_model(
     )
 
     model.train()
-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)
 
     # Train model
     train(
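The selection logic reduces to mapping a name to an optimizer class and splatting the per-optimizer kwargs into its constructor. The same dispatch can be written as a lookup table; the sketch below is an illustrative refactor, assuming mlx is installed and that the kwargs match the installed mlx.optimizers constructor signatures, not the code that was merged.

# Illustrative lookup-table version of the name -> class dispatch above.
import mlx.optimizers as optim

OPT_CLASSES = {"adam": optim.Adam, "adamw": optim.AdamW}

def make_optimizer(name: str, learning_rate, config: dict):
    name = name.lower()
    if name not in OPT_CLASSES:
        raise ValueError(f"Unsupported optimizer: {name}")
    # learning_rate may be a float or a schedule built by build_schedule()
    return OPT_CLASSES[name](learning_rate=learning_rate, **config)

opt = make_optimizer("adamw", 1e-5, {"weight_decay": 0.05})

Adding another optimizer then only requires extending the if/elif chain (or the table above) and giving it an entry under optimizer_config.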