diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py index 187467f4..9bd572e3 100644 --- a/llms/mlx_lm/tuner/trainer.py +++ b/llms/mlx_lm/tuner/trainer.py @@ -224,7 +224,7 @@ def train( n_tokens += toks.item() # Report training loss if needed - if (it + 1) % args.steps_per_report == 0: + if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters): train_loss = np.mean(losses) stop = time.perf_counter() @@ -259,7 +259,7 @@ def train( start = time.perf_counter() # Report validation loss if needed - if it == 0 or (it + 1) % args.steps_per_eval == 0: + if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters): stop = time.perf_counter() val_loss = evaluate( model=model,