diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index c0fb5536..3c8a4a0f 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -121,18 +121,20 @@ def evaluate(
 
 class TrainingCallback:
 
-    def on_train_loss_report(
-        self,
-        steps: int,
-        loss: float,
-        it_sec: float,
-        tokens_sec: float,
-        trained_tokens: int,
-    ):
-        """Called to report training loss at specified intervals."""
+    def on_train_loss_report(self, train_info: dict):
+        """Called to report training loss at specified intervals.
+
+        ``train_info`` has the keys ``iteration``, ``train_loss``,
+        ``learning_rate``, ``iterations_per_second``,
+        ``tokens_per_second`` and ``trained_tokens``.
+        """
         pass
 
-    def on_val_loss_report(self, steps: int, loss: float, val_time: float):
-        """Called to report validation loss at specified intervals or the beginning."""
+    def on_val_loss_report(self, val_info: dict):
+        """Called to report validation loss at specified intervals or the beginning.
+
+        ``val_info`` has the keys ``iteration``, ``val_loss`` and
+        ``val_time``.
+        """
         pass
 
@@ -146,7 +148,7 @@ def train(
     args: TrainingArgs = TrainingArgs(),
     loss: callable = default_loss,
     iterate_batches: callable = iterate_batches,
-    training_callback=None,
+    training_callback: TrainingCallback = None,
 ):
     print(f"Starting training..., iters: {args.iters}")
 
@@ -189,20 +191,28 @@ def train(
             train_loss = np.mean(losses)
 
             stop = time.perf_counter()
+            learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
             tokens_sec = float(n_tokens) / (stop - start)
             trained_tokens += n_tokens
             print(
                 f"Iter {it + 1}: Train loss {train_loss:.3f}, "
+                f"Learning Rate {learning_rate:.3e}, "
                 f"It/sec {it_sec:.3f}, "
                 f"Tokens/sec {tokens_sec:.3f}, "
                 f"Trained Tokens {trained_tokens}"
             )
 
             if training_callback is not None:
-                training_callback.on_train_loss_report(
-                    it + 1, train_loss, it_sec, tokens_sec, trained_tokens
-                )
+                train_info = {
+                    "iteration": it + 1,
+                    "train_loss": train_loss,
+                    "learning_rate": learning_rate,
+                    "iterations_per_second": it_sec,
+                    "tokens_per_second": tokens_sec,
+                    "trained_tokens": trained_tokens,
+                }
+                training_callback.on_train_loss_report(train_info)
 
             losses = []
             n_tokens = 0
@@ -229,7 +239,12 @@ def train(
             )
 
             if training_callback is not None:
-                training_callback.on_val_loss_report(it + 1, val_loss, val_time)
+                val_info = {
+                    "iteration": it + 1,
+                    "val_loss": val_loss,
+                    "val_time": val_time,
+                }
+                training_callback.on_val_loss_report(val_info)
 
             start = time.perf_counter()
 