From eed093b0ecea1c21428a01bfc0a059652de28504 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sat, 1 Mar 2025 15:00:29 +0100
Subject: [PATCH] Add more customizable YAML configuration for optimizers

---
 llms/mlx_lm/lora.py | 53 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 46 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 0cc40508..ed19450f 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -44,6 +44,25 @@ CONFIG_DEFAULTS = {
     "train": False,
     "fine_tune_type": "lora",
     "optimizer": "adam",
+    "optimizer_config": {
+        "adam": {
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "bias_correction": False
+        },
+        "adamw": {
+            "betas": [0.9, 0.999],
+            "eps": 1e-8,
+            "weight_decay": 0.01,
+            "bias_correction": False
+        },
+        "muon": {
+            "momentum": 0.95,
+            "weight_decay": 0.01,
+            "nesterov": True,
+            "ns_steps": 5
+        }
+    },
     "data": "data/",
     "seed": 0,
     "num_layers": 16,
@@ -239,14 +258,34 @@ def train_model(
 
     # Initialize the selected optimizer
     lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-    if args.optimizer.lower() == "adam":
-        opt = optim.Adam(learning_rate=lr)
-    elif args.optimizer.lower() == "adamw":
-        opt = optim.AdamW(learning_rate=lr)
-    elif args.optimizer.lower() == "muon":
-        opt = optim.Muon(learning_rate=lr)
+    optimizer_name = args.optimizer.lower()
+    optimizer_config = args.optimizer_config.get(optimizer_name, {})
+
+    if optimizer_name == "adam":
+        opt = optim.Adam(
+            learning_rate=lr,
+            betas=optimizer_config.get("betas", [0.9, 0.999]),
+            eps=optimizer_config.get("eps", 1e-8),
+            bias_correction=optimizer_config.get("bias_correction", False)
+        )
+    elif optimizer_name == "adamw":
+        opt = optim.AdamW(
+            learning_rate=lr,
+            betas=optimizer_config.get("betas", [0.9, 0.999]),
+            eps=optimizer_config.get("eps", 1e-8),
+            weight_decay=optimizer_config.get("weight_decay", 0.01),
+            bias_correction=optimizer_config.get("bias_correction", False)
+        )
+    elif optimizer_name == "muon":
+        opt = optim.Muon(
+            learning_rate=lr,
+            momentum=optimizer_config.get("momentum", 0.95),
+            weight_decay=optimizer_config.get("weight_decay", 0.01),
+            nesterov=optimizer_config.get("nesterov", True),
+            ns_steps=optimizer_config.get("ns_steps", 5)
+        )
     else:
-        raise ValueError(f"Unsupported optimizer: {args.optimizer}")
+        raise ValueError(f"Unsupported optimizer: {optimizer_name}")
 
     # Train model
     train(
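
Usage note: the new `optimizer_config` defaults above are merged into args the same
way as the other `CONFIG_DEFAULTS` entries, so they can be overridden from the YAML
file passed to `mlx_lm.lora` with the existing `--config`/`-c` flag. The following is
a minimal sketch of such a config; the file name, the model id, and the specific
override values are illustrative only and not part of this patch. Any keys left out
of the YAML fall back to the `.get(...)` defaults in `train_model`:

    # lora_config.yaml -- illustrative example, not part of this patch
    model: "mlx-community/Mistral-7B-Instruct-v0.3-4bit"   # hypothetical model id
    train: true
    fine_tune_type: "lora"
    optimizer: "adamw"
    optimizer_config:
      adamw:
        betas: [0.9, 0.999]
        eps: 1.0e-8
        weight_decay: 0.05      # overrides the 0.01 default
        bias_correction: true   # overrides the False default

With a file like this, running the LoRA entry point with `--config lora_config.yaml`
would construct `optim.AdamW` with the overridden weight decay and bias correction,
while the `adam` and `muon` branches keep their built-in defaults via the `.get(...)`
fallbacks.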