From 42413c5d851668abbc7f919107f06cf92b4e153d Mon Sep 17 00:00:00 2001
From: Awni Hannun
Date: Tue, 11 Feb 2025 16:48:55 -0800
Subject: [PATCH] fix lora timings after validation (#1278)

---
 llms/mlx_lm/tuner/trainer.py | 18 +++++++++---------
 1 file changed, 9 insertions(+), 9 deletions(-)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index d675f9b6..64e26af8 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -233,8 +233,8 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
-    start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
         iterate_batches(
@@ -245,10 +245,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -259,7 +260,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)
 
-            start = time.perf_counter()
+            tic = time.perf_counter()
 
         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic
 
         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -322,7 +322,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0
 
         # Save adapter weights
         if it % args.steps_per_save == 0:
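
Note (not part of the patch): below is a minimal standalone sketch of the timing pattern the diff adopts, namely accumulating only the time spent in training steps so that the periodic validation pass no longer inflates the reported it/sec and tokens/sec. The functions run_validation and run_training_step are placeholders standing in for evaluate(...) and step(batch); only the handling of train_time mirrors the patched code.

import time


def run_validation():
    # Placeholder for evaluate(...); just burns a little wall-clock time.
    time.sleep(0.01)


def run_training_step():
    # Placeholder for step(batch); returns a fake token count.
    time.sleep(0.001)
    return 512


steps_per_report = 10
steps_per_eval = 100
train_time = 0.0  # replaces the old single `start` timestamp

for it in range(1, 201):
    tic = time.perf_counter()

    if it == 1 or it % steps_per_eval == 0:
        run_validation()
        tic = time.perf_counter()  # restart the clock so validation time is excluded

    n_tokens = run_training_step()
    train_time += time.perf_counter() - tic  # accumulate training time only

    if it % steps_per_report == 0:
        it_sec = steps_per_report / train_time
        tokens_sec = n_tokens / train_time
        print(f"Iter {it}: {it_sec:.2f} it/sec, {tokens_sec:.1f} tok/sec")
        train_time = 0.0  # reset after each report, as the patch does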