diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index d32bfe6d..d4ce57a3 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -63,6 +63,7 @@ CONFIG_DEFAULTS = {
     "lr_schedule": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "mask_prompt": False,
+    "report_to_wandb": False,
 }
 
 
@@ -169,6 +170,12 @@ def build_parser():
         help="Use gradient checkpointing to reduce memory use.",
         default=None,
     )
+    parser.add_argument(
+        "--report-to-wandb",
+        action="store_true",
+        help="Report training metrics and config to WandB.",
+        default=None,
+    )
     parser.add_argument("--seed", type=int, help="The PRNG seed")
     return parser
 
@@ -267,6 +274,26 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
 def run(args, training_callback: TrainingCallback = None):
     np.random.seed(args.seed)
 
+    # Initialize WandB if requested
+    if args.report_to_wandb:
+        import wandb
+        wandb.init(project="mlx-finetuning", config=vars(args))
+
+        # Create a simple wandb callback that wraps the existing one
+        original_callback = training_callback
+        class WandBCallback(TrainingCallback):
+            def on_train_loss_report(self, train_info: dict):
+                wandb.log(train_info)
+                if original_callback:
+                    original_callback.on_train_loss_report(train_info)
+
+            def on_val_loss_report(self, val_info: dict):
+                wandb.log(val_info)
+                if original_callback:
+                    original_callback.on_val_loss_report(val_info)
+
+        training_callback = WandBCallback()
+
     print("Loading pretrained model")
     model, tokenizer = load(args.model)
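
Usage note (sketch, not part of the patch): logging stays opt-in, since report_to_wandb defaults to False in CONFIG_DEFAULTS. Assuming the usual mlx_lm.lora entry point and an installed wandb package, a run would look roughly like:

    python -m mlx_lm.lora --model <model-path> --train --data <data-path> --report-to-wandb

Because the key is added to CONFIG_DEFAULTS, it should also be settable from a YAML config passed via -c/--config. For programmatic callers of run(), WandBCallback wraps rather than replaces any training_callback they pass in: wandb.init records vars(args) as the run config, wandb.log receives the same train_info / val_info dicts the trainer already emits, and the original callback's on_train_loss_report / on_val_loss_report hooks still fire afterwards.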