From 45769461512f227abc3d0ba40e49118e851259db Mon Sep 17 00:00:00 2001
From: Ivan Fioravanti
Date: Mon, 12 Feb 2024 19:50:05 +0100
Subject: [PATCH] Add checkpoints directory for adapter weights (#431)

* Add checkpoints directory for adapter weights

The code was modified to create a checkpoints directory if it doesn't
exist yet. Adapter weights are now saved to this checkpoints directory
during the training iterations.

Also corrected the indentation of the "Save adapter weights" block,
which was previously nested inside the "if eval" block.

* Fixing a blank line added by mistake
---
 llms/mlx_lm/tuner/trainer.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 2d92a98f..2bfc8142 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -127,6 +127,10 @@ def train(
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
 ):
+    # Create checkpoints directory if it does not exist
+    if not os.path.exists('checkpoints'):
+        os.makedirs('checkpoints')
+
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
 
@@ -191,12 +195,13 @@ def train(
             start = time.perf_counter()
 
-            # Save adapter weights if needed
-            if (it + 1) % args.steps_per_save == 0:
-                save_adapter(model=model, adapter_file=args.adapter_file)
-                print(
-                    f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
-                )
+        # Save adapter weights if needed
+        if (it + 1) % args.steps_per_save == 0:
+            checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
+            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
+            print(
+                f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
+            )
 
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
     print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
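
Note on the resulting layout: each periodic save now lands under checkpoints/
with the iteration number prefixed to the adapter file name, while the final
weights still go to args.adapter_file unchanged. A minimal sketch of the naming
pattern the new loop produces; steps_per_save=100, iters=300, and
adapter_file="adapters.npz" are illustrative values, not TrainingArgs defaults:

    # Sketch only: mirrors the checkpoint-naming logic from the hunk above.
    steps_per_save = 100          # assumed example value
    iters = 300                   # assumed example value
    adapter_file = "adapters.npz" # assumed example value

    for it in range(iters):
        if (it + 1) % steps_per_save == 0:
            # Same f-string pattern as checkpoint_adapter_file in the diff
            print(f"checkpoints/{it + 1}_{adapter_file}")

    # Output:
    # checkpoints/100_adapters.npz
    # checkpoints/200_adapters.npz
    # checkpoints/300_adapters.npz

One caveat worth noting: because the checkpoint path is built by prefixing the
literal "checkpoints/" to args.adapter_file, an adapter_file that already
contains a directory component (e.g. a full path) would produce a nested path
such as checkpoints/100_some/dir/adapters.npz.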