initial commit

Goekdeniz-Guelmez 2025-03-01 14:56:06 +01:00
parent 845cd8c01e
commit b0a2edbcf3

@@ -43,6 +43,7 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -95,14 +96,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
+    parser.add_argument(
+        '--optimizer',
+        type=str,
+        choices=["adam", "adamw", "muon"],
+        default="adam",
+        help="Optimizer to use for training: adam, adamw, or muon",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )
     parser.add_argument(
         "--num-layers",
         type=int,
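The flag behaves like any other argparse choice: valid values parse cleanly, anything else exits with an "invalid choice" error. A self-contained sketch mirroring the new argument, detached from the real `build_parser` so it runs standalone:

```python
import argparse

# Mirrors the new --optimizer flag from the diff above.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--optimizer",
    type=str,
    choices=["adam", "adamw", "muon"],
    default="adam",
    help="Optimizer to use for training: adam, adamw, or muon",
)

print(parser.parse_args([]).optimizer)                       # -> adam (default)
print(parser.parse_args(["--optimizer", "muon"]).optimizer)  # -> muon
# parser.parse_args(["--optimizer", "sgd"]) would exit with "invalid choice"
```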
@@ -229,11 +235,18 @@ def train_model(
     )
     model.train()

-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    if args.optimizer.lower() == "adam":
+        opt = optim.Adam(learning_rate=lr)
+    elif args.optimizer.lower() == "adamw":
+        opt = optim.AdamW(learning_rate=lr)
+    elif args.optimizer.lower() == "muon":
+        opt = optim.Muon(learning_rate=lr)
+    else:
+        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
+
     # Train model
     train(
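The if/elif chain above could equally be table-driven. A hedged sketch of that variant, not the committed code, assuming `optim` is `mlx.optimizers` and that `optim.Muon` exists in the installed MLX version (older releases may not ship it):

```python
import mlx.optimizers as optim

# Table-driven variant of the selection logic in the patch (a sketch, not the
# committed code). Assumes optim.Muon is available in the installed MLX version.
OPTIMIZER_CLASSES = {
    "adam": optim.Adam,
    "adamw": optim.AdamW,
    "muon": optim.Muon,
}

def make_optimizer(name, lr):
    # lr may be a float or a schedule callable; MLX optimizers accept either.
    try:
        return OPTIMIZER_CLASSES[name.lower()](learning_rate=lr)
    except KeyError:
        raise ValueError(f"Unsupported optimizer: {name}") from None
```

A lookup table keeps the valid names in one place, so the argparse `choices` list and the dispatch logic are less likely to drift apart as optimizers are added.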