LoRA: support path information in the adapter file argument (#505)

* LoRA: support path information in the adapter file argument

* fix pre-commit lint

* from os.path to pathlib.Path

* Update llms/mlx_lm/tuner/trainer.py

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>

* rename check_checkpoints_path to checkpoints_path

* fix pre-commit lint

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
Madroid Ma 2024-03-01 14:20:49 +08:00 committed by GitHub
parent ae48563378
commit f03c8a7b44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1,6 +1,6 @@
import os
import time
from dataclasses import dataclass, field
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
@ -143,9 +143,17 @@ def train(
):
print(f"Starting training..., iters: {args.iters}")
def checkpoints_path(adapter_file) -> str:
    """Return the path of a ``checkpoints`` directory next to *adapter_file*,
    creating it (including parents) if it does not already exist.

    Args:
        adapter_file: Path (str or os.PathLike) to the adapter weights file;
            the checkpoints directory is placed in its parent directory.

    Returns:
        The checkpoints directory as a string.
    """
    # NOTE: ``Path(adapter_file).parent`` is never falsy — a bare filename
    # yields ``Path('.')``, which is truthy — so the original "default to
    # Path('checkpoints')" branch was dead code. Deriving the directory
    # unconditionally from the parent is behavior-identical.
    path = Path(adapter_file).parent / "checkpoints"
    path.mkdir(parents=True, exist_ok=True)
    return str(path)
# Create checkpoints directory if it does not exist
if not os.path.exists("checkpoints"):
os.makedirs("checkpoints")
adapter_path = checkpoints_path(args.adapter_file)
# Create value and grad function for loss
loss_value_and_grad = nn.value_and_grad(model, loss)
@ -241,15 +249,15 @@ def train(
# Save adapter weights if needed
if (it + 1) % args.steps_per_save == 0:
checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
save_adapter(model=model, adapter_file=checkpoint_adapter_file)
print(
f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
checkpoint_adapter_file = (
f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
)
save_adapter(model=model, adapter_file=checkpoint_adapter_file)
print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
# save final adapter weights
save_adapter(model=model, adapter_file=args.adapter_file)
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
print(f"Saved final adapter weights to {args.adapter_file}.")
def save_adapter(