Validation with full data set results in NaN validation score (#879)

* CLI arguments may set `num_batches` to -1

The CLI arguments allow you to validate with the entire dataset by passing `-1`, but doing so produced a division by zero, and `NaN` appeared as the validation score!
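Why `-1` yields `NaN` rather than an exception: `range(-1)` is empty, so the evaluation loop never runs, no tokens are counted, and the final average divides zero by zero. A minimal sketch of the failure mode (the closing `np.sum(all_losses) / ntokens` reduction is assumed from the surrounding code in lora.py):

```python
import numpy as np

num_batches = -1  # CLI value meaning "use the entire validation set"
all_losses, ntokens = [], 0

# range(-1) is empty, so zip never yields and the loop body never runs:
for it, batch in zip(range(num_batches), ["batch0", "batch1"]):
    ntokens += 1  # never reached

# NumPy turns 0.0 / 0 into nan (with a RuntimeWarning) rather than
# raising, so the bad value silently becomes the validation score.
print(np.sum(all_losses) / ntokens)  # nan
```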

* Must properly assemble the mini-batches when validating with the entire dataset.

Tested locally: validating on a full novel took about an hour and finished with a loss of 0.928. Thanks @awni for the correction!

* Set up the pre-commit hooks and run them so that `black` formats lora.py.
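For anyone reproducing the formatting step, the standard pre-commit CLI flow looks like this (assuming the repository ships a `.pre-commit-config.yaml` with a `black` hook):

```
pre-commit install          # register the git hook
pre-commit run --all-files  # run black (and any other hooks) across the tree
```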
Commit 9717307ff0, authored by James A Capozzoli on 2024-07-10 11:36:11 -04:00 (parent 63800c8feb).

@@ -220,8 +220,12 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
 def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
     all_losses = []
     ntokens = 0
+    # num_batches can be -1 to indicate the entire set
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
     for it, batch in zip(
-        range(num_batches),
+        index_iterator,
         iterate_batches(dataset, tokenizer, batch_size),
     ):
         losses, toks = loss(model, *batch)
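The replacement works because of Python's two-argument form of `iter`: `iter(callable, sentinel)` calls `callable()` on each step and stops only when the result equals `sentinel`. Since `int()` always returns `0` and never `1`, `iter(int, 1)` is an endless iterator, so `zip` now terminates only when `iterate_batches` is exhausted, i.e. after the whole validation set. A small demonstration:

```python
from itertools import islice

# iter(callable, sentinel) calls `callable` until it returns `sentinel`;
# int() always returns 0, never 1, so this iterator never ends.
endless = iter(int, 1)
print(list(islice(endless, 5)))  # [0, 0, 0, 0, 0]

# zip stops at its shortest input, so pairing the endless iterator with
# the batch stream consumes every batch exactly once:
batches = iter(["batch0", "batch1", "batch2"])  # stand-in for iterate_batches(...)
for it, batch in zip(endless, batches):
    print(it, batch)  # prints: 0 batch0, 0 batch1, 0 batch2
```

Note that in the full-dataset case the loop counter `it` is always 0; in the hunk above `it` is not used beyond the unpacking, so the constant value is harmless.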