LoRA: support path information in the adapter file (#505)

* LoRA: support path information in the adapter file

* fix pre-commit lint

* switch from os.path to pathlib.Path

* Update llms/mlx_lm/tuner/trainer.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* rename check_checkpoints_path to checkpoints_path

* fix pre-commit lint

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
Madroid Ma 2024-03-01 14:20:49 +08:00 committed by GitHub
parent ae48563378
commit f03c8a7b44

llms/mlx_lm/tuner/trainer.py

@@ -1,6 +1,6 @@
-import os
 import time
 from dataclasses import dataclass, field
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -143,9 +143,17 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
+    def checkpoints_path(adapter_file) -> str:
+        checkpoints_path = Path("checkpoints")
+        if Path(adapter_file).parent:
+            checkpoints_path = Path(adapter_file).parent / "checkpoints"
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+        return str(checkpoints_path)
+
     # Create checkpoints directory if it does not exist
-    if not os.path.exists("checkpoints"):
-        os.makedirs("checkpoints")
+    adapter_path = checkpoints_path(args.adapter_file)
 
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
@@ -241,15 +249,15 @@ def train(
         # Save adapter weights if needed
         if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(
-                f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
-            )
+            checkpoint_adapter_file = (
+                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
+            )
+            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
+            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
 
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
-    print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
+    print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
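For context, the change makes intermediate checkpoints land next to the adapter file rather than in a hard-coded "checkpoints" directory under the current working directory. Below is a minimal standalone sketch of the resulting behavior; the helper body is taken from the diff above, while the example adapter path and iteration number are hypothetical and only illustrate how the names are composed.

from pathlib import Path


def checkpoints_path(adapter_file) -> str:
    # Default to a local "checkpoints" directory.
    checkpoints_path = Path("checkpoints")
    # Note: Path(adapter_file).parent is "." for a bare file name, which is
    # still truthy, so "adapters.npz" resolves to the local "checkpoints"
    # directory while "runs/exp1/adapters.npz" resolves to "runs/exp1/checkpoints".
    if Path(adapter_file).parent:
        checkpoints_path = Path(adapter_file).parent / "checkpoints"
    checkpoints_path.mkdir(parents=True, exist_ok=True)
    return str(checkpoints_path)


# Hypothetical adapter file passed on the command line.
adapter_file = "runs/exp1/adapters.npz"
adapter_path = checkpoints_path(adapter_file)  # "runs/exp1/checkpoints"

# Periodic checkpoints are named "<iteration>_<adapter file name>".
it = 99
checkpoint_adapter_file = f"{adapter_path}/{it + 1}_{Path(adapter_file).name}"
print(checkpoint_adapter_file)  # runs/exp1/checkpoints/100_adapters.npz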