From 9717307ff0f08e3dc683ac68c2e3851f1a5184cd Mon Sep 17 00:00:00 2001
From: James A Capozzoli <157492257+jac-jim@users.noreply.github.com>
Date: Wed, 10 Jul 2024 11:36:11 -0400
Subject: [PATCH] Validation with full data set, results in NaN validation
 score (#879)

* CLI arguments may set num_batches to -1

The CLI arguments allow you to validate with the entire dataset by passing -1,
but this quickly results in a division by zero and a `NaN` validation score!

* Must properly assemble the mini batches when validating with the entire
dataset. Tested locally, a validation of a novel took about an hour, with a
loss of 0.928. Thanks @awni for the correction!

* Set up the pre-commit hooks and run them so that black may format lora.py.
---
 lora/lora.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/lora/lora.py b/lora/lora.py
index a90eda70..723e783d 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -220,8 +220,12 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
 def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
     all_losses = []
     ntokens = 0
+
+    # num_batches can be -1 to indicate the entire set
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
     for it, batch in zip(
-        range(num_batches),
+        index_iterator,
         iterate_batches(dataset, tokenizer, batch_size),
     ):
         losses, toks = loss(model, *batch)
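
Note on the `iter(int, 1)` idiom the patch relies on: Python's two-argument
`iter(callable, sentinel)` calls the callable until it returns the sentinel.
Because `int()` always returns 0 and never equals 1, `iter(int, 1)` yields an
endless stream, so `zip` stops only when `iterate_batches` runs out of data.
The sketch below is illustrative only; `fake_batches` and `take_batches` are
made-up stand-ins, not part of lora.py.

```python
def fake_batches(n):
    # Stand-in for iterate_batches(): yields n dummy batches.
    for i in range(n):
        yield f"batch-{i}"


def take_batches(num_batches, batches):
    # Mirror of the patched logic: -1 means "iterate over the entire set".
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
    return [batch for _, batch in zip(index_iterator, batches)]


print(take_batches(2, fake_batches(5)))   # ['batch-0', 'batch-1']
print(take_batches(-1, fake_batches(5)))  # all five batches
```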