mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 18:36:38 +08:00
Add checkpoints directory for adapter weights (#431)
* Add checkpoints directory for adapter weights. The code was modified to create a `checkpoints` directory if it doesn't exist yet. Adapter weights are now saved to this `checkpoints` directory during the training iterations. Also corrected the indentation of the "Save adapter weights" code, because it was mistakenly nested inside the "if eval" block. * Fix a blank line added by mistake
This commit is contained in:
parent
70465b8cda
commit
4576946151
@ -127,6 +127,10 @@ def train(
|
||||
args: TrainingArgs = TrainingArgs(),
|
||||
loss: callable = default_loss,
|
||||
):
|
||||
# Create checkpoints directory if it does not exist
|
||||
if not os.path.exists('checkpoints'):
|
||||
os.makedirs('checkpoints')
|
||||
|
||||
# Create value and grad function for loss
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
|
||||
@ -191,12 +195,13 @@ def train(
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save adapter weights if needed
|
||||
if (it + 1) % args.steps_per_save == 0:
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(
|
||||
f"Iter {it + 1}: Saved adapter weights to {os.path.join(args.adapter_file)}."
|
||||
)
|
||||
# Save adapter weights if needed
|
||||
if (it + 1) % args.steps_per_save == 0:
|
||||
checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
|
||||
save_adapter(model=model, adapter_file=checkpoint_adapter_file)
|
||||
print(
|
||||
f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
|
||||
)
|
||||
# save final adapter weights
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
|
||||
|
Loading…
Reference in New Issue
Block a user