From 70465b8cda77c59d5689560b52503a55c49b9534 Mon Sep 17 00:00:00 2001 From: Lee Harrold <35541778+Harrolee@users.noreply.github.com> Date: Mon, 12 Feb 2024 08:46:00 -0500 Subject: [PATCH] clean up loss function extraction (#433) Co-authored-by: Lee Harrold --- transformer_lm/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/transformer_lm/main.py b/transformer_lm/main.py index 214518d1..e5ec37dd 100644 --- a/transformer_lm/main.py +++ b/transformer_lm/main.py @@ -99,7 +99,7 @@ def main(args): for s in range(0, targets.shape[0], batch_size): bx, by = inputs[s : s + batch_size], targets[s : s + batch_size] bx, by = map(mx.array, (bx, by)) - losses = loss(bx, by, reduce=False) + losses = loss_fn(model, bx, by, reduce=False) loss += mx.sum(losses).item() return loss / len(targets) @@ -131,7 +131,7 @@ def main(args): losses = [] tic = time.perf_counter() if (it + 1) % steps_per_eval == 0: - val_loss = eval_fn(model, valid) + val_loss = eval_fn(valid) toc = time.perf_counter() print( f"Iter {it + 1}: " @@ -142,7 +142,7 @@ def main(args): tic = time.perf_counter() if args.eval_test: - test_loss = eval_fn(model, test) + test_loss = eval_fn(test) test_ppl = math.exp(test_loss) print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")