Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-24 09:21:18 +08:00
Adding multiple optimizers to mlx lm (#1315)
* initial commit
* adding more customized YAML configuration
* update YAML example file
* Changed the switch to set opt_class
* removing muon
* using default arguments
* update
This commit is contained in:
parent 56d2db23e1
commit e150621095
@@ -7,6 +7,15 @@ train: true
 # The fine-tuning method: "lora", "dora", or "full".
 fine_tune_type: lora

+# The Optimizer with its possible inputs
+optimizer: adamw
+
+# optimizer_config:
+#   adamw:
+#     betas: [0.9, 0.98]
+#     eps: 1e-6
+#     weight_decay: 0.05
+#     bias_correction: true

 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
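To make the commented-out block concrete, here is a minimal sketch (not part of the diff) of how the adamw settings above end up as keyword arguments to MLX's AdamW once the config is uncommented. The learning rate of 1e-5 is an arbitrary placeholder, and bias_correction is left out of the sketch.

# Minimal sketch, assuming mlx is installed; 1e-5 is a placeholder learning rate.
import mlx.optimizers as optim

# What the YAML block above parses into once the optimizer_config section is uncommented.
optimizer_config = {
    "adamw": {
        "betas": [0.9, 0.98],
        "eps": 1e-6,
        "weight_decay": 0.05,
    }
}

# The per-optimizer dict is splatted into the constructor, mirroring train_model() below.
opt = optim.AdamW(learning_rate=1e-5, **optimizer_config["adamw"])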
@@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {},
+        "adamw": {},
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
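For context, a rough sketch of how these defaults are expected to interact with the CLI flag and a YAML config; this resolution logic is an assumption based on how the existing defaults behave, not part of this diff. Values left unset on the command line fall back to the YAML file and then to CONFIG_DEFAULTS, which is why the new --optimizer flag uses default=None while "adam" here preserves the previous behaviour.

# Hedged sketch of the config fallback; resolve_training_args is a hypothetical helper.
import argparse
import yaml

CONFIG_DEFAULTS = {
    "optimizer": "adam",
    "optimizer_config": {"adam": {}, "adamw": {}},
}

def resolve_training_args(args: argparse.Namespace, config_path=None) -> argparse.Namespace:
    yaml_config = {}
    if config_path:
        with open(config_path) as f:
            yaml_config = yaml.safe_load(f) or {}
    for key, default in CONFIG_DEFAULTS.items():
        if getattr(args, key, None) is None:
            # CLI value wins, then the YAML file, then the built-in default.
            setattr(args, key, yaml_config.get(key, default))
    return args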
@@ -95,14 +100,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw"],
+        default=None,
+        help="Optimizer to use for training: adam or adamw",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )

     parser.add_argument(
         "--num-layers",
         type=int,
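Assuming build_parser is importable from mlx_lm.lora (an assumption about the module layout, not shown in this diff), a quick check that the new flag parses and is constrained by its choices list might look like this:

# Hypothetical smoke test for the new flag; not part of the commit.
from mlx_lm.lora import build_parser

parser = build_parser()
args = parser.parse_args(["--train", "--optimizer", "adamw"])
print(args.optimizer)  # "adamw"

# An unsupported name is rejected by argparse before training starts:
# parser.parse_args(["--optimizer", "sgd"]) exits with "invalid choice: 'sgd'".

Leaving default=None rather than "adam" lets a value from the YAML config or CONFIG_DEFAULTS take effect when the flag is omitted.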
@@ -229,11 +239,21 @@ def train_model(
     )

     model.train()
-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)
+
     # Train model
     train(
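As a design note, the if/elif chain could equally be written as a table-driven lookup. A sketch of that alternative (my rephrasing, not what the commit does) keeps the name-to-class mapping in one place and makes adding a third optimizer a one-line change:

# Alternative dispatch sketch; make_optimizer is a hypothetical helper.
import mlx.optimizers as optim

OPT_CLASSES = {"adam": optim.Adam, "adamw": optim.AdamW}

def make_optimizer(name, learning_rate, config):
    try:
        opt_class = OPT_CLASSES[name.lower()]
    except KeyError:
        raise ValueError(f"Unsupported optimizer: {name}")
    # config holds the per-optimizer kwargs from optimizer_config, e.g. weight_decay.
    return opt_class(learning_rate=learning_rate, **config)

opt = make_optimizer("adamw", 1e-5, {"weight_decay": 0.05})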