bug fix in transformer_lm example

Leon Ericsson 2023-12-06 17:42:23 +01:00 committed by Angelos Katharopoulos
parent 8b965b2e33
commit e488831e03


@@ -81,13 +81,13 @@ def main(args):
     optimizer = optim.SGD(learning_rate=args.learning_rate)
     loss_and_grad_fn = nn.value_and_grad(model, model.loss)

-    def eval_fn(params, dataset):
+    def eval_fn(model, dataset):
         inputs, targets = map(mx.array, to_samples(context_size, dataset))
         loss = 0
         for s in range(0, targets.shape[0], batch_size):
             bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
             bx, by = map(mx.array, (bx, by))
-            losses = self.loss(bx, by, reduce=False)
+            losses = model.loss(bx, by, reduce=False)
             loss += mx.sum(losses).item()
         return loss / len(targets)
@@ -110,9 +110,8 @@ def main(args):
             )
             losses = []
             tic = time.perf_counter()
         if (it + 1) % steps_per_eval == 0:
-            val_loss = eval_fn(params, valid)
+            val_loss = eval_fn(model, valid)
             toc = time.perf_counter()
             print(
                 f"Iter {it + 1}: "
@@ -123,7 +122,7 @@ def main(args):
             tic = time.perf_counter()

     if args.eval_test:
-        test_loss = eval_fn(params, test)
+        test_loss = eval_fn(model, test)
         test_ppl = math.exp(test_loss)
         print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")