From e15062109568571aec0e2f099533ad580f0fcaf5 Mon Sep 17 00:00:00 2001
From: Gökdeniz Gülmez <60228478+Goekdeniz-Guelmez@users.noreply.github.com>
Date: Wed, 5 Mar 2025 22:54:54 +0100
Subject: [PATCH] Adding multiple optimizers to mlx lm (#1315)

* initial commit

* adding more customized YAML configuration

* update YAML example file

* Changed the switch to set opt_class

* removing muon

* using default arguments

* update
---
 llms/mlx_lm/examples/lora_config.yaml |  9 +++++++
 llms/mlx_lm/lora.py                   | 34 +++++++++++++++++++++------
 2 files changed, 36 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index 530272c7..36bc1dff 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -7,6 +7,15 @@ train: true
 # The fine-tuning method: "lora", "dora", or "full".
 fine_tune_type: lora
 
+# The Optimizer with its possible inputs
+optimizer: adamw
+# optimizer_config:
+#   adamw:
+#     betas: [0.9, 0.98]
+#     eps: 1e-6
+#     weight_decay: 0.05
+#     bias_correction: true
+
 # Directory with {train, valid, test}.jsonl files
 data: "/path/to/training/data"
 
diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index d32bfe6d..042b40e2 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -43,6 +43,11 @@ CONFIG_DEFAULTS = {
     "model": "mlx_model",
     "train": False,
     "fine_tune_type": "lora",
+    "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {},
+        "adamw": {},
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -95,14 +100,19 @@ def build_parser():
         choices=["lora", "dora", "full"],
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
-
+    parser.add_argument(
+        "--optimizer",
+        type=str,
+        choices=["adam", "adamw"],
+        default=None,
+        help="Optimizer to use for training: adam or adamw",
+    )
     parser.add_argument(
         "--mask-prompt",
         action="store_true",
         help="Mask the prompt in the loss when training",
         default=None,
     )
-
     parser.add_argument(
         "--num-layers",
         type=int,
@@ -229,11 +239,21 @@ def train_model(
     )
 
     model.train()
-    opt = optim.Adam(
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)
 
     # Train model
     train(
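
Usage note: with this patch applied, the optimizer is selected by the "optimizer" key in the YAML config (or by the new --optimizer flag), and any keyword arguments listed under "optimizer_config" for that optimizer are forwarded unchanged to its constructor via opt_class(learning_rate=lr, **optimizer_config). Below is a minimal sketch of a config that switches to AdamW and overrides two of its defaults; it assumes betas and weight_decay are valid keyword arguments of mlx.optimizers.AdamW, since the names are passed through as-is:

    # Select AdamW instead of the default Adam and override two of its settings
    optimizer: adamw
    optimizer_config:
      adamw:
        betas: [0.9, 0.98]
        weight_decay: 0.05

Optimizers with no entry under "optimizer_config" fall back to an empty keyword dictionary (see CONFIG_DEFAULTS above), so they run with the library defaults.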