From e488831e0366f439abb6f0f98f6df5748cb38c03 Mon Sep 17 00:00:00 2001
From: Leon Ericsson
Date: Wed, 6 Dec 2023 17:42:23 +0100
Subject: [PATCH] bug fix in transformer_lm example

eval_fn took a `params` argument it never used and called `self.loss`,
but it is a free function rather than a method, so `self` is undefined
and the call raises a NameError. Pass the model in explicitly and call
`model.loss` instead.
---
 transformer_lm/main.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/transformer_lm/main.py b/transformer_lm/main.py
index f94003fd..b0d2966b 100644
--- a/transformer_lm/main.py
+++ b/transformer_lm/main.py
@@ -81,13 +81,13 @@ def main(args):
     optimizer = optim.SGD(learning_rate=args.learning_rate)
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)
 
-    def eval_fn(params, dataset):
+    def eval_fn(model, dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = self.loss(bx, by, reduce=False)
+            losses = model.loss(bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)
 
@@ -110,9 +110,8 @@ def main(args):
             )
             losses = []
             tic = time.perf_counter()
-
         if (it + 1) % steps_per_eval == 0:
-            val_loss = eval_fn(params, valid)
+            val_loss = eval_fn(model, valid)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: "
@@ -123,7 +122,7 @@ def main(args):
             tic = time.perf_counter()
 
     if args.eval_test:
-        test_loss = eval_fn(params, test)
+        test_loss = eval_fn(model, test)
         test_ppl = math.exp(test_loss)
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
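
For reference, the corrected pattern in isolation: a minimal, runnable
sketch assuming only that mlx is installed. TinyLM, the random data,
and the sizes below are hypothetical stand-ins for the example's
TransformerLM and its dataset; eval_fn mirrors the fixed function
above, taking the model as an explicit argument and calling
model.loss.

    import mlx.core as mx
    import mlx.nn as nn


    class TinyLM(nn.Module):
        # Hypothetical stand-in for the example's TransformerLM.
        def __init__(self, vocab_size: int, dims: int):
            super().__init__()
            self.embedding = nn.Embedding(vocab_size, dims)
            self.out_proj = nn.Linear(dims, vocab_size)

        def __call__(self, x):
            return self.out_proj(self.embedding(x))

        def loss(self, x, y, reduce=True):
            logits = self(x)
            losses = nn.losses.cross_entropy(logits, y)
            return mx.mean(losses) if reduce else losses


    def eval_fn(model, inputs, targets, batch_size=16):
        # The fix: take the model as an argument and call model.loss.
        # The buggy version took an unused `params` and called
        # `self.loss`, which is a NameError in a free function.
        loss = 0
        for s in range(0, targets.shape[0], batch_size):
            bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
            losses = model.loss(bx, by, reduce=False)
            loss += mx.sum(losses).item()
        return loss / targets.shape[0]


    vocab_size, context_size = 100, 8
    model = TinyLM(vocab_size, dims=32)
    inputs = mx.random.randint(0, vocab_size, (64, context_size))
    targets = mx.random.randint(0, vocab_size, (64, context_size))
    print(f"eval loss: {eval_fn(model, inputs, targets):.3f}")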