mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
LoRA: adapter file Support path information (#505)
* LoRA: adapter file Support path information * fix pre-commit lint * from os.path to pathlib.Path * Update llms/mlx_lm/tuner/trainer.py Co-authored-by: Awni Hannun <awni.hannun@gmail.com> * rename check_checkpoints_path to checkpoints_path * fix pre-commit lint --------- Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
This commit is contained in:
parent
ae48563378
commit
f03c8a7b44
@ -1,6 +1,6 @@
|
||||
import os
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
@ -143,9 +143,17 @@ def train(
|
||||
):
|
||||
print(f"Starting training..., iters: {args.iters}")
|
||||
|
||||
def checkpoints_path(adapter_file) -> str:
|
||||
checkpoints_path = Path("checkpoints")
|
||||
if Path(adapter_file).parent:
|
||||
checkpoints_path = Path(adapter_file).parent / "checkpoints"
|
||||
|
||||
checkpoints_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return str(checkpoints_path)
|
||||
|
||||
# Create checkpoints directory if it does not exist
|
||||
if not os.path.exists("checkpoints"):
|
||||
os.makedirs("checkpoints")
|
||||
adapter_path = checkpoints_path(args.adapter_file)
|
||||
|
||||
# Create value and grad function for loss
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss)
|
||||
@ -241,15 +249,15 @@ def train(
|
||||
|
||||
# Save adapter weights if needed
|
||||
if (it + 1) % args.steps_per_save == 0:
|
||||
checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
|
||||
save_adapter(model=model, adapter_file=checkpoint_adapter_file)
|
||||
print(
|
||||
f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
|
||||
checkpoint_adapter_file = (
|
||||
f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
|
||||
)
|
||||
save_adapter(model=model, adapter_file=checkpoint_adapter_file)
|
||||
print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
|
||||
|
||||
# save final adapter weights
|
||||
save_adapter(model=model, adapter_file=args.adapter_file)
|
||||
print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
|
||||
print(f"Saved final adapter weights to {args.adapter_file}.")
|
||||
|
||||
|
||||
def save_adapter(
|
||||
|
Loading…
Reference in New Issue
Block a user