diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index d32bfe6d..0cc40508 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -43,6 +43,7 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -95,14 +96,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
-
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw", "muon"],
+        default="adam",
+        help="Optimizer to use for training: adam, adamw, or muon",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )
-
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -229,11 +235,18 @@ def train_model(
     )
 
     model.train()
-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    if args.optimizer.lower() == "adam":
+        opt = optim.Adam(learning_rate=lr)
+    elif args.optimizer.lower() == "adamw":
+        opt = optim.AdamW(learning_rate=lr)
+    elif args.optimizer.lower() == "muon":
+        opt = optim.Muon(learning_rate=lr)
+    else:
+        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
 
     # Train model
     train(
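
As a side note for reviewers, the new `if/elif` dispatch in `train_model` could equivalently be written as a lookup table, which keeps the error message in one place and makes adding another optimizer a one-line change. The sketch below is illustrative only, not part of the patch: `build_optimizer` and `OPTIMIZER_CLASSES` are hypothetical names, and it assumes `mlx.optimizers` exposes `Adam`, `AdamW`, and `Muon` with a `learning_rate` keyword (`Muon` is only available in recent MLX releases).

```python
# Sketch only (not part of the diff): table-driven optimizer selection,
# equivalent in behavior to the if/elif chain added in train_model.
import mlx.optimizers as optim

# Assumes Adam, AdamW, and Muon all exist in the installed MLX version.
OPTIMIZER_CLASSES = {
    "adam": optim.Adam,
    "adamw": optim.AdamW,
    "muon": optim.Muon,
}


def build_optimizer(name: str, learning_rate):
    """Construct the optimizer named by `name` with the given learning rate
    (a float or a schedule, e.g. the result of build_schedule)."""
    try:
        opt_class = OPTIMIZER_CLASSES[name.lower()]
    except KeyError:
        raise ValueError(f"Unsupported optimizer: {name}")
    return opt_class(learning_rate=learning_rate)
```

With such a helper, the call site in `train_model` would reduce to `opt = build_optimizer(args.optimizer, lr)`; the existing chain in the diff works just as well, this is purely a readability suggestion.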