diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 2d92a98f..2bfc8142 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -127,6 +127,10 @@ def train(
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
 ):
+    # Create checkpoints directory if it does not exist
+    if not os.path.exists('checkpoints'):
+        os.makedirs('checkpoints')
+
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
@@ -191,12 +195,13 @@ def train(
             start = time.perf_counter()
 
-        # Save adapter weights if needed
-        if (it + 1) % args.steps_per_save == 0:
-            save_adapter(model=model, adapter_file=args.adapter_file)
-            print(
-                f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
-            )
+        # Save adapter weights if needed
+        if (it + 1) % args.steps_per_save == 0:
+            checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
+            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
+            print(
+                f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
+            )
 
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
    print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
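
For reference, a minimal sketch (not part of the patch) of the checkpoint naming this change produces. The helper checkpoint_path and the example values here are hypothetical; the sketch also shows one way to handle an adapter_file that contains a directory component, which the patch's f-string would place under a never-created subdirectory of checkpoints/.

    import os

    def checkpoint_path(it: int, adapter_file: str) -> str:
        # Hypothetical helper mirroring the patch's f-string, except that it
        # strips any directory component from adapter_file so the checkpoint
        # always lands directly inside checkpoints/.
        return os.path.join("checkpoints", f"{it + 1}_{os.path.basename(adapter_file)}")

    print(checkpoint_path(99, "adapters.npz"))       # checkpoints/100_adapters.npz
    print(checkpoint_path(199, "out/adapters.npz"))  # checkpoints/200_adapters.npz

With the patch as written and, say, steps_per_save=100 and adapter_file="adapters.npz", the training loop saves checkpoints/100_adapters.npz, checkpoints/200_adapters.npz, and so on, while the final weights still go to adapters.npz as before.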