Adding multiple optimizers to mlx lm (#1315)

* initial commmit

* adding more customized YAML configuartion

* update YAML example file

* Changed the switch to set opt_class

* removing muon

* using default arguments

* udpate
This commit is contained in:
Gökdeniz Gülmez 2025-03-05 22:54:54 +01:00 committed by GitHub
parent 56d2db23e1
commit e150621095
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 36 additions and 7 deletions

View File

@ -7,6 +7,15 @@ train: true
# The fine-tuning method: "lora", "dora", or "full".
fine_tune_type: lora
# The Optimizer with its possible inputs
optimizer: adamw
# optimizer_config:
# adamw:
# betas: [0.9, 0.98]
# eps: 1e-6
# weight_decay: 0.05
# bias_correction: true
# Directory with {train, valid, test}.jsonl files
data: "/path/to/training/data"

View File

@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
"model": "mlx_model",
"train": False,
"fine_tune_type": "lora",
"optimizer": "adam",
"optimizer_config": {
"adam": {},
"adamw": {},
},
"data": "data/",
"seed": 0,
"num_layers": 16,
@ -95,14 +100,19 @@ def build_parser():
choices=["lora", "dora", "full"],
help="Type of fine-tuning to perform: lora, dora, or full.",
)
parser.add_argument(
"--optimizer",
type=str,
choices=["adam", "adamw"],
default=None,
help="Optimizer to use for training: adam or adamw",
)
parser.add_argument(
"--mask-prompt",
action="store_true",
help="Mask the prompt in the loss when training",
default=None,
)
parser.add_argument(
"--num-layers",
type=int,
@ -229,11 +239,21 @@ def train_model(
)
model.train()
opt = optim.Adam(
learning_rate=(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
# Initialize the selected optimizer
lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
optimizer_name = args.optimizer.lower()
optimizer_config = args.optimizer_config.get(optimizer_name, {})
if optimizer_name == "adam":
opt_class = optim.Adam
elif optimizer_name == "adamw":
opt_class = optim.AdamW
else:
raise ValueError(f"Unsupported optimizer: {optimizer_name}")
opt = opt_class(learning_rate=lr, **optimizer_config)
# Train model
train(