From 8eee4399f4b2c8c071da90a8b10a71ee6e14292a Mon Sep 17 00:00:00 2001
From: Madroid Ma
Date: Wed, 21 Feb 2024 05:07:21 +0800
Subject: [PATCH] LoRA: Add printing and callbacks for learning rate during
 training (#457)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* LoRA: Refactor TrainingCallback to enhance flexibility and extensibility

This commit refactors the TrainingCallback class so that both
on_train_loss_report and on_val_loss_report accept a single dict
parameter instead of several positional arguments. Passing a dict makes
it straightforward to report new training or validation metrics without
changing the method signatures, and simplifies adding new information
to be logged or processed.

* LoRA: Add printing and callbacks for learning rate during training
---
 llms/mlx_lm/tuner/trainer.py | 34 ++++++++++++++++++++--------------
 1 file changed, 20 insertions(+), 14 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index c0fb5536..3c8a4a0f 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -121,18 +121,11 @@ def evaluate(
 
 
 class TrainingCallback:
-    def on_train_loss_report(
-        self,
-        steps: int,
-        loss: float,
-        it_sec: float,
-        tokens_sec: float,
-        trained_tokens: int,
-    ):
+    def on_train_loss_report(self, train_info: dict):
         """Called to report training loss at specified intervals."""
         pass
 
-    def on_val_loss_report(self, steps: int, loss: float, val_time: float):
+    def on_val_loss_report(self, val_info: dict):
         """Called to report validation loss at specified intervals or the beginning."""
         pass
 
@@ -146,7 +139,7 @@ def train(
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
-    training_callback=None,
+    training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
 
@@ -189,20 +182,28 @@ def train(
             train_loss = np.mean(losses)
 
             stop = time.perf_counter()
+            learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
                 f"Trained Tokens {trained_tokens}"
             )
 
             if training_callback is not None:
-                training_callback.on_train_loss_report(
-                    it + 1, train_loss, it_sec, tokens_sec, trained_tokens
-                )
+                train_info = {
+                    "iteration": it + 1,
+                    "train_loss": train_loss,
+                    "learning_rate": learning_rate,
+                    "iterations_per_second": it_sec,
+                    "tokens_per_second": tokens_sec,
+                    "trained_tokens": trained_tokens,
+                }
+                training_callback.on_train_loss_report(train_info)
 
             losses = []
             n_tokens = 0
@@ -229,7 +230,12 @@ def train(
             )
 
             if training_callback is not None:
-                training_callback.on_val_loss_report(it + 1, val_loss, val_time)
+                val_info = {
+                    "iteration": it + 1,
+                    "val_loss": val_loss,
+                    "val_time": val_time
+                }
+                training_callback.on_val_loss_report(val_info)
 
             start = time.perf_counter()
 
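
Usage example (not part of the patch): a minimal sketch of a callback that
consumes the new dict-based API. The TrainingCallback base class, the method
names, and the dict keys come from the patched trainer above; the
MetricsLogger name, the import path (assuming the mlx_lm package is installed
so llms/mlx_lm/tuner/trainer.py resolves as mlx_lm.tuner.trainer), and the
commented train() invocation are illustrative assumptions.

    from mlx_lm.tuner.trainer import TrainingCallback

    class MetricsLogger(TrainingCallback):
        """Collects the metric dicts reported during training and validation."""

        def __init__(self):
            self.train_reports = []
            self.val_reports = []

        def on_train_loss_report(self, train_info: dict):
            # Keys set by the patched trainer: iteration, train_loss,
            # learning_rate, iterations_per_second, tokens_per_second,
            # trained_tokens.
            self.train_reports.append(train_info)
            print(
                f"iter {train_info['iteration']}: "
                f"loss {train_info['train_loss']:.3f}, "
                f"lr {train_info['learning_rate']:.3e}"
            )

        def on_val_loss_report(self, val_info: dict):
            # Keys set by the patched trainer: iteration, val_loss, val_time.
            self.val_reports.append(val_info)

    # The instance is passed through the new keyword argument, roughly:
    #   train(..., training_callback=MetricsLogger())

Keeping the per-report metrics in a dict means a future field (for example a
gradient norm) only needs a new entry in train_info; existing callbacks that
ignore unknown keys keep working without signature changes.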