Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
LoRA: adapter file Support path information (#505)

* LoRA: adapter file Support path information
* fix pre-commit lint
* from os.path to pathlib.Path
* Update llms/mlx_lm/tuner/trainer.py
  Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
* rename check_checkpoints_path to checkpoints_path
* fix pre-commit lint

---------

Co-authored-by: Awni Hannun <awni.hannun@gmail.com>
parent: ae48563378
commit: f03c8a7b44
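The core of the change is the move from os.path string handling to pathlib.Path, which the diff below uses to derive the checkpoints directory from the adapter file's location. As a minimal sketch of the pathlib operations the diff relies on (the path here is a hypothetical example, not a value from the repo):

    from pathlib import Path

    # Hypothetical adapter path, for illustration only.
    adapter_file = Path("out/run1/adapters.npz")

    print(adapter_file.name)                    # adapters.npz
    print(adapter_file.parent)                  # out/run1
    print(adapter_file.parent / "checkpoints")  # out/run1/checkpoints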
llms/mlx_lm/tuner/trainer.py

@@ -1,6 +1,6 @@
-import os
 import time
 from dataclasses import dataclass, field
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -143,9 +143,17 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
+    def checkpoints_path(adapter_file) -> str:
+        checkpoints_path = Path("checkpoints")
+        if Path(adapter_file).parent:
+            checkpoints_path = Path(adapter_file).parent / "checkpoints"
+
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        return str(checkpoints_path)
+
     # Create checkpoints directory if it does not exist
-    if not os.path.exists("checkpoints"):
-        os.makedirs("checkpoints")
+    adapter_path = checkpoints_path(args.adapter_file)
 
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
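A usage sketch of the checkpoints_path helper introduced above (file names hypothetical): the checkpoints directory is resolved relative to wherever the adapter file lives, and is created as a side effect of the mkdir call.

    # Assumes the checkpoints_path definition from the hunk above is in scope.
    print(checkpoints_path("out/run1/adapters.npz"))  # out/run1/checkpoints
    print(checkpoints_path("adapters.npz"))           # checkpoints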
@@ -241,15 +249,15 @@ def train(
 
         # Save adapter weights if needed
         if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
+            checkpoint_adapter_file = (
+                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
+            )
             save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(
-                f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
-            )
+            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
 
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
-    print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
+    print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
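The per-checkpoint filename now keeps only the adapter file's basename. A small sketch with hypothetical values shows why this matters: under the old f-string, a path-qualified adapter file would have been spliced whole into the checkpoint name.

    from pathlib import Path

    # Hypothetical values; in trainer.py these come from args and the loop.
    adapter_file = "out/run1/adapters.npz"
    adapter_path = "out/run1/checkpoints"  # what checkpoints_path() would return
    it = 99                                # zero-based iteration counter

    # Old behavior: the full path leaks into the filename.
    print(f"checkpoints/{it + 1}_{adapter_file}")
    # checkpoints/100_out/run1/adapters.npz

    # New behavior: only the basename is used, inside the resolved directory.
    print(f"{adapter_path}/{it + 1}_{Path(adapter_file).name}")
    # out/run1/checkpoints/100_adapters.npz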