From f03c8a7b44dbf97a1cc4512ba73563cd7b8272a2 Mon Sep 17 00:00:00 2001
From: Madroid Ma
Date: Fri, 1 Mar 2024 14:20:49 +0800
Subject: [PATCH] LoRA: adapter file Support path information (#505)

* LoRA: adapter file Support path information

* fix pre-commit lint

* from os.path to pathlib.Path

* Update llms/mlx_lm/tuner/trainer.py

Co-authored-by: Awni Hannun

* rename check_checkpoints_path to checkpoints_path

* fix pre-commit lint

---------

Co-authored-by: Awni Hannun
---
 llms/mlx_lm/tuner/trainer.py | 24 ++++++++++++++++--------
 1 file changed, 16 insertions(+), 8 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index bc3f2811..43ab66a6 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -1,6 +1,6 @@
-import os
 import time
 from dataclasses import dataclass, field
+from pathlib import Path
 
 import mlx.core as mx
 import mlx.nn as nn
@@ -143,9 +143,17 @@ def train(
 ):
     print(f"Starting training..., iters: {args.iters}")
 
+    def checkpoints_path(adapter_file) -> str:
+        checkpoints_path = Path("checkpoints")
+        if Path(adapter_file).parent:
+            checkpoints_path = Path(adapter_file).parent / "checkpoints"
+
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        return str(checkpoints_path)
+
     # Create checkpoints directory if it does not exist
-    if not os.path.exists("checkpoints"):
-        os.makedirs("checkpoints")
+    adapter_path = checkpoints_path(args.adapter_file)
 
     # Create value and grad function for loss
     loss_value_and_grad = nn.value_and_grad(model, loss)
@@ -241,15 +249,15 @@ def train(
 
         # Save adapter weights if needed
         if (it + 1) % args.steps_per_save == 0:
-            checkpoint_adapter_file = f"checkpoints/{it + 1}_{args.adapter_file}"
-            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
-            print(
-                f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
-            )
+            checkpoint_adapter_file = (
+                f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}"
+            )
+            save_adapter(model=model, adapter_file=checkpoint_adapter_file)
+            print(f"Iter {it + 1}: Saved adapter weights to {checkpoint_adapter_file}.")
 
     # save final adapter weights
     save_adapter(model=model, adapter_file=args.adapter_file)
-    print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
+    print(f"Saved final adapter weights to {args.adapter_file}.")
 
 
 def save_adapter(
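For reference, a minimal standalone sketch of what the new checkpoints_path helper resolves to, using the function body exactly as added in this patch; the example adapter-file paths (out/lora/adapters.npz, adapters.npz) are hypothetical:

    from pathlib import Path

    def checkpoints_path(adapter_file) -> str:
        # Same logic as the helper added in this patch: place the
        # "checkpoints" directory next to the adapter file rather than
        # in a fixed location.
        checkpoints_path = Path("checkpoints")
        if Path(adapter_file).parent:
            checkpoints_path = Path(adapter_file).parent / "checkpoints"
        # Note: mkdir creates the directory as a side effect.
        checkpoints_path.mkdir(parents=True, exist_ok=True)
        return str(checkpoints_path)

    print(checkpoints_path("out/lora/adapters.npz"))  # out/lora/checkpoints
    print(checkpoints_path("adapters.npz"))           # checkpoints

Combined with the new save path, f"{adapter_path}/{it + 1}_{Path(args.adapter_file).name}", an adapter file of out/lora/adapters.npz saved at iteration 100 would land at out/lora/checkpoints/100_adapters.npz.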