LoRA: report last train info (#595)

This commit is contained in:
madroid
2024-03-20 08:29:50 +08:00
committed by GitHub
parent 4680ef4413
commit 39d5ca6427

View File

@@ -224,7 +224,7 @@ def train(
n_tokens += toks.item()
# Report training loss if needed
if (it + 1) % args.steps_per_report == 0:
if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters):
train_loss = np.mean(losses)
stop = time.perf_counter()
@@ -259,7 +259,7 @@ def train(
start = time.perf_counter()
# Report validation loss if needed
if it == 0 or (it + 1) % args.steps_per_eval == 0:
if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
stop = time.perf_counter()
val_loss = evaluate(
model=model,