LoRA: report last train info (#595)

This commit is contained in:
madroid 2024-03-20 08:29:50 +08:00 committed by GitHub
parent 4680ef4413
commit 39d5ca6427
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -224,7 +224,7 @@ def train(
n_tokens += toks.item() n_tokens += toks.item()
# Report training loss if needed # Report training loss if needed
if (it + 1) % args.steps_per_report == 0: if ((it + 1) % args.steps_per_report == 0) or (it + 1 == args.iters):
train_loss = np.mean(losses) train_loss = np.mean(losses)
stop = time.perf_counter() stop = time.perf_counter()
@ -259,7 +259,7 @@ def train(
start = time.perf_counter() start = time.perf_counter()
# Report validation loss if needed # Report validation loss if needed
if it == 0 or (it + 1) % args.steps_per_eval == 0: if it == 0 or ((it + 1) % args.steps_per_eval == 0) or (it + 1 == args.iters):
stop = time.perf_counter() stop = time.perf_counter()
val_loss = evaluate( val_loss = evaluate(
model=model, model=model,