From 73cc094681cba5c6cc9d72c23bbda910b5584c17 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sat, 8 Mar 2025 22:36:20 +0100
Subject: [PATCH] fix optimizer

---
 llms/mlx_lm/lora.py | 35 +++++++++++++++--------------------
 1 file changed, 15 insertions(+), 20 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 797b0d8a..724b4297 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -255,11 +255,21 @@ def train_model(
     save_config(vars(args), adapter_path / "adapter_config.json")
 
     model.train()
-    opt = optim.Adam(  # need to correct that part
-        learning_rate=(
-            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-        )
-    )
+
+    # Initialize the selected optimizer
+    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
+
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt_class = optim.Adam
+    elif optimizer_name == "adamw":
+        opt_class = optim.AdamW
+    else:
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
+
+    opt = opt_class(learning_rate=lr, **optimizer_config)
 
     # Train model based on training mode
     if args.training_mode == "orpo":
@@ -298,21 +308,6 @@ def train_model(
         max_seq_length=args.max_seq_length,
         grad_checkpoint=args.grad_checkpoint
     )
-
-    # Initialize the selected optimizer
-    lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-
-    optimizer_name = args.optimizer.lower()
-    optimizer_config = args.optimizer_config.get(optimizer_name, {})
-
-    if optimizer_name == "adam":
-        opt_class = optim.Adam
-    elif optimizer_name == "adamw":
-        opt_class = optim.AdamW
-    else:
-        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
-
-    opt = opt_class(learning_rate=lr, **optimizer_config)
 
     train(
         model=model,
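
Note (not part of the patch): below is a minimal, self-contained sketch of the optimizer dispatch this patch moves ahead of training, so it can be tried outside of lora.py. The build_optimizer helper, the SimpleNamespace stand-in for the parsed CLI/YAML arguments, the learning rate, and the weight_decay value are illustrative assumptions; only optim.Adam, optim.AdamW, and the args.optimizer / args.optimizer_config usage come from the patch itself.

# Sketch of the optimizer selection introduced in train_model() above.
# "args" is a hypothetical stand-in for the parsed CLI/config arguments.
from types import SimpleNamespace

import mlx.optimizers as optim


def build_optimizer(args, lr):
    # Per-optimizer keyword arguments are looked up under the optimizer's name.
    optimizer_name = args.optimizer.lower()
    optimizer_config = args.optimizer_config.get(optimizer_name, {})

    if optimizer_name == "adam":
        opt_class = optim.Adam
    elif optimizer_name == "adamw":
        opt_class = optim.AdamW
    else:
        raise ValueError(f"Unsupported optimizer: {optimizer_name}")

    # Extra keys (e.g. weight_decay for AdamW) are forwarded to the class.
    return opt_class(learning_rate=lr, **optimizer_config)


# Example: select AdamW and pass an assumed weight_decay override through
# optimizer_config, mirroring how args.optimizer / args.optimizer_config are used.
args = SimpleNamespace(
    optimizer="adamw",
    optimizer_config={"adamw": {"weight_decay": 0.01}},
)
opt = build_optimizer(args, lr=1e-5)

Keying optimizer_config by optimizer name lets each optimizer carry its own keyword set without sharing a single flag namespace, which is why the patch forwards the sub-dict with **optimizer_config.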