clean up loss function extraction (#433)

Co-authored-by: Lee Harrold <lhharrold@sep.com>
This commit is contained in:
Lee Harrold 2024-02-12 08:46:00 -05:00 committed by GitHub
parent f1ef378a58
commit 70465b8cda
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -99,7 +99,7 @@ def main(args):
for s in range(0, targets.shape[0], batch_size): for s in range(0, targets.shape[0], batch_size):
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size] bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
bx, by = map(mx.array, (bx, by)) bx, by = map(mx.array, (bx, by))
losses = loss(bx, by, reduce=False) losses = loss_fn(model, bx, by, reduce=False)
loss += mx.sum(losses).item() loss += mx.sum(losses).item()
return loss / len(targets) return loss / len(targets)
@ -131,7 +131,7 @@ def main(args):
losses = [] losses = []
tic = time.perf_counter() tic = time.perf_counter()
if (it + 1) % steps_per_eval == 0: if (it + 1) % steps_per_eval == 0:
val_loss = eval_fn(model, valid) val_loss = eval_fn(valid)
toc = time.perf_counter() toc = time.perf_counter()
print( print(
f"Iter {it + 1}: " f"Iter {it + 1}: "
@ -142,7 +142,7 @@ def main(args):
tic = time.perf_counter() tic = time.perf_counter()
if args.eval_test: if args.eval_test:
test_loss = eval_fn(model, test) test_loss = eval_fn(test)
test_ppl = math.exp(test_loss) test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.") print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")