mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-09 10:26:38 +08:00
clean up loss function extraction (#433)
Co-authored-by: Lee Harrold <lhharrold@sep.com>
This commit is contained in:
parent
f1ef378a58
commit
70465b8cda
@ -99,7 +99,7 @@ def main(args):
|
||||
for s in range(0, targets.shape[0], batch_size):
|
||||
bx, by = inputs[s : s + batch_size], targets[s : s + batch_size]
|
||||
bx, by = map(mx.array, (bx, by))
|
||||
losses = loss(bx, by, reduce=False)
|
||||
losses = loss_fn(model, bx, by, reduce=False)
|
||||
loss += mx.sum(losses).item()
|
||||
return loss / len(targets)
|
||||
|
||||
@ -131,7 +131,7 @@ def main(args):
|
||||
losses = []
|
||||
tic = time.perf_counter()
|
||||
if (it + 1) % steps_per_eval == 0:
|
||||
val_loss = eval_fn(model, valid)
|
||||
val_loss = eval_fn(valid)
|
||||
toc = time.perf_counter()
|
||||
print(
|
||||
f"Iter {it + 1}: "
|
||||
@ -142,7 +142,7 @@ def main(args):
|
||||
tic = time.perf_counter()
|
||||
|
||||
if args.eval_test:
|
||||
test_loss = eval_fn(model, test)
|
||||
test_loss = eval_fn(test)
|
||||
test_ppl = math.exp(test_loss)
|
||||
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user