diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index ed19450f..eefe97f5 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -45,23 +45,14 @@ CONFIG_DEFAULTS = {
     "fine_tune_type": "lora",
     "optimizer": "adam",
    "optimizer_config": {
-        "adam": {
-            "betas": [0.9, 0.999],
-            "eps": 1e-8,
-            "bias_correction": False
-        },
-        "adamw": {
-            "betas": [0.9, 0.999],
-            "eps": 1e-8,
-            "weight_decay": 0.01,
-            "bias_correction": False
-        },
+        "adam": {"bias_correction": False},
+        "adamw": {"weight_decay": 0.01, "bias_correction": False},
         "muon": {
             "momentum": 0.95,
             "weight_decay": 0.01,
             "nesterov": True,
-            "ns_steps": 5
-        }
+            "ns_steps": 5,
+        },
     },
     "data": "data/",
     "seed": 0,
@@ -116,7 +107,7 @@ def build_parser():
         help="Type of fine-tuning to perform: lora, dora, or full.",
     )
     parser.add_argument(
-        '--optimizer',
+        "--optimizer",
         type=str,
         choices=["adam", "adamw", "muon"],
         default="adam",
@@ -257,36 +248,21 @@ def train_model(

     # Initialize the selected optimizer
     lr = build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
-    
+
     optimizer_name = args.optimizer.lower()
     optimizer_config = args.optimizer_config.get(optimizer_name, {})
-    
+
     if optimizer_name == "adam":
-        opt = optim.Adam(
-            learning_rate=lr,
-            betas=optimizer_config.get("betas", [0.9, 0.999]),
-            eps=optimizer_config.get("eps", 1e-8),
-            bias_correction=optimizer_config.get("bias_correction", False)
-        )
+        opt_class = optim.Adam
     elif optimizer_name == "adamw":
-        opt = optim.AdamW(
-            learning_rate=lr,
-            betas=optimizer_config.get("betas", [0.9, 0.999]),
-            eps=optimizer_config.get("eps", 1e-8),
-            weight_decay=optimizer_config.get("weight_decay", 0.01),
-            bias_correction=optimizer_config.get("bias_correction", False)
-        )
+        opt_class = optim.AdamW
     elif optimizer_name == "muon":
-        opt = optim.Muon(
-            learning_rate=lr,
-            momentum=optimizer_config.get("momentum", 0.95),
-            weight_decay=optimizer_config.get("weight_decay", 0.01),
-            nesterov=optimizer_config.get("nesterov", True),
-            ns_steps=optimizer_config.get("ns_steps", 5)
-        )
+        opt_class = optim.Muon
     else:
         raise ValueError(f"Unsupported optimizer: {optimizer_name}")

+    opt = opt_class(learning_rate=lr, **optimizer_config)
+
     # Train model
     train(
         model=model,
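
Below is a minimal standalone sketch (hypothetical helper, not part of the patch) of the dispatch pattern this change introduces: pick the optimizer class, then splat the per-optimizer config dict into the constructor, so defaults for betas/eps/etc. live in the MLX optimizer classes rather than in CONFIG_DEFAULTS.

    # Sketch only; make_optimizer and OPT_CLASSES are illustrative names,
    # not identifiers from lora.py.
    import mlx.optimizers as optim

    OPT_CLASSES = {"adam": optim.Adam, "adamw": optim.AdamW, "muon": optim.Muon}

    def make_optimizer(name, lr, optimizer_config):
        try:
            opt_class = OPT_CLASSES[name.lower()]
        except KeyError:
            raise ValueError(f"Unsupported optimizer: {name}")
        # Keys the user sets (e.g. "betas", "eps", "weight_decay") pass
        # through; anything unset falls back to the class's own default.
        return opt_class(learning_rate=lr, **optimizer_config)

    # Equivalent to optim.AdamW(learning_rate=1e-5, weight_decay=0.01)
    opt = make_optimizer("adamw", 1e-5, {"weight_decay": 0.01})

One consequence of splatting the config directly instead of reading each key with .get(): a misspelled key now fails fast with a TypeError from the constructor rather than being silently ignored.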