From b0a2edbcf33a299f322b9356891472185d6bb443 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sat, 1 Mar 2025 14:56:06 +0100
Subject: [PATCH] initial commit

---
 llms/mlx_lm/lora.py | 27 ++++++++++++++++++++-------
 1 file changed, 20 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index d32bfe6d..0cc40508 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -43,6 +43,7 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -95,14 +96,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
-
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw", "muon"],
+        default="adam",
+        help="Optimizer to use for training: adam, adamw, or muon.",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )
-
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -229,11 +235,18 @@ def train_model(
     )
 
     model.train()
-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    if args.optimizer.lower() == "adam":
+        opt = optim.Adam(learning_rate=lr)
+    elif args.optimizer.lower() == "adamw":
+        opt = optim.AdamW(learning_rate=lr)
+    elif args.optimizer.lower() == "muon":
+        opt = optim.Muon(learning_rate=lr)
+    else:
+        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
 
     # Train model
     train(
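
Note (not part of the patch): the if/elif dispatch added in train_model() can also be written as a table lookup. The sketch below is illustrative only; OPTIMIZER_CLASSES and build_optimizer are hypothetical names, and it assumes the installed mlx version exposes optim.Muon, exactly as the patch itself assumes.

# Illustrative sketch only -- not code from lora.py. It mirrors the optimizer
# selection added in train_model() as a dictionary lookup instead of an
# if/elif chain. OPTIMIZER_CLASSES and build_optimizer are hypothetical names;
# optim.Muon is assumed to be available, as the patch assumes.
import mlx.optimizers as optim

OPTIMIZER_CLASSES = {
    "adam": optim.Adam,
    "adamw": optim.AdamW,
    "muon": optim.Muon,
}


def build_optimizer(name, learning_rate):
    """Return the optimizer selected by --optimizer, or raise on unknown names."""
    try:
        return OPTIMIZER_CLASSES[name.lower()](learning_rate=learning_rate)
    except KeyError:
        raise ValueError(f"Unsupported optimizer: {name}")


if __name__ == "__main__":
    # Roughly what train_model() would construct for --optimizer muon
    # with a fixed learning rate (no --lr-schedule).
    opt = build_optimizer("muon", 1e-5)
    print(type(opt).__name__)  # -> Muon

A lookup table keeps the supported choices and the construction logic in one place, which may be easier to extend if further mlx optimizers (e.g. Lion or SGD) are exposed later.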