From 0ba466369faf7f42040a52393cd00d56b7f4e692 Mon Sep 17 00:00:00 2001
From: Madroid Ma
Date: Fri, 16 Feb 2024 22:04:57 +0800
Subject: [PATCH] LoRA: add training callbacks (#414)

* LoRA: add training callbacks

* LoRA: add trained tokens print & callback
---
 llms/mlx_lm/lora.py          |  1 -
 llms/mlx_lm/tuner/trainer.py | 37 +++++++++++++++++++++++++++++++-----
 2 files changed, 32 insertions(+), 6 deletions(-)

diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 75093080..9e26225b 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -7,7 +7,6 @@
 import mlx.optimizers as optim
 import numpy as np
 from mlx.utils import tree_flatten
-from .tuner.lora import LoRALinear
 from .tuner.trainer import TrainingArgs, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import generate, load
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index b2c5ba69..ae7e1fc7 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -119,6 +119,17 @@ def evaluate(
     return np.sum(all_losses) / ntokens
 
 
+class TrainingCallback:
+
+    def on_train_loss_report(self, steps: int, loss: float, it_sec: float, tokens_sec: float, trained_tokens: int):
+        """Called to report training loss at specified intervals."""
+        pass
+
+    def on_val_loss_report(self, steps: int, loss: float, val_time: float):
+        """Called to report validation loss at specified intervals or the beginning."""
+        pass
+
+
 def train(
     model,
     tokenizer,
@@ -127,8 +138,11 @@ def train(
     val_dataset,
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
-    iterate_batches: callable = iterate_batches
+    iterate_batches: callable = iterate_batches,
+    training_callback=None,
 ):
+    print(f"Starting training..., iters: {args.iters}")
+
     # Create checkpoints directory if it does not exist
     if not os.path.exists("checkpoints"):
         os.makedirs("checkpoints")
@@ -138,7 +152,7 @@ def train(
 
     losses = []
     n_tokens = 0
-    print("Starting training..., iters:", args.iters)
+    trained_tokens = 0
     # Main training loop
     start = time.perf_counter()
     for it, batch in zip(
@@ -168,11 +182,19 @@ def train(
             train_loss = np.mean(losses)
             stop = time.perf_counter()
+            it_sec = args.steps_per_report / (stop - start)
+            tokens_sec = float(n_tokens) / (stop - start)
+            trained_tokens += n_tokens
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
-                f"It/sec {args.steps_per_report / (stop - start):.3f}, "
-                f"Tokens/sec {float(n_tokens) / (stop - start):.3f}"
+                f"It/sec {it_sec:.3f}, "
+                f"Tokens/sec {tokens_sec:.3f}, "
+                f"Trained Tokens {trained_tokens}"
             )
+
+            if training_callback is not None:
+                training_callback.on_train_loss_report(it + 1, train_loss, it_sec, tokens_sec, trained_tokens)
+
             losses = []
             n_tokens = 0
             start = time.perf_counter()
 
@@ -190,12 +212,16 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches
             )
+            val_time = time.perf_counter() - stop
             print(
                 f"Iter {it + 1}: "
                 f"Val loss {val_loss:.3f}, "
-                f"Val took {(time.perf_counter() - stop):.3f}s"
+                f"Val took {val_time:.3f}s"
             )
 
+            if training_callback is not None:
+                training_callback.on_val_loss_report(it + 1, val_loss, val_time)
+
             start = time.perf_counter()
 
         # Save adapter weights if needed
@@ -205,6 +231,7 @@ def train(
         print(
             f"Iter {it + 1}: Saved adapter weights to {os.path.join(checkpoint_adapter_file)}."
        )
+
    # save final adapter weights
    save_adapter(model=model, adapter_file=args.adapter_file)
    print(f"Saved final adapter weights to {os.path.join(args.adapter_file)}.")
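
For context, below is a minimal sketch of how the new training_callback hook introduced by this patch might be consumed from user code. The MetricsLogger class, its in-memory history lists, and the commented call site are illustrative assumptions rather than part of the patch; the import assumes the patched mlx_lm package is installed.

from mlx_lm.tuner.trainer import TrainingCallback


class MetricsLogger(TrainingCallback):
    """Hypothetical callback that keeps the reported metrics in memory."""

    def __init__(self):
        self.train_history = []
        self.val_history = []

    def on_train_loss_report(self, steps, loss, it_sec, tokens_sec, trained_tokens):
        # Arguments mirror what train() passes every args.steps_per_report iterations.
        self.train_history.append(
            {
                "step": steps,
                "loss": loss,
                "it_sec": it_sec,
                "tokens_sec": tokens_sec,
                "trained_tokens": trained_tokens,
            }
        )

    def on_val_loss_report(self, steps, loss, val_time):
        # Arguments mirror what train() passes after each validation pass.
        self.val_history.append({"step": steps, "loss": loss, "val_time": val_time})


# Hypothetical call site: the model, tokenizer, optimizer, and datasets would be
# prepared as in llms/mlx_lm/lora.py before invoking
#   train(..., training_callback=MetricsLogger())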