diff --git a/lora/lora.py b/lora/lora.py
index a90eda70..723e783d 100644
--- a/lora/lora.py
+++ b/lora/lora.py
@@ -220,8 +220,12 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
 def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
     all_losses = []
     ntokens = 0
+
+    # num_batches can be -1 to indicate the entire set
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
     for it, batch in zip(
-        range(num_batches),
+        index_iterator,
         iterate_batches(dataset, tokenizer, batch_size),
     ):
         losses, toks = loss(model, *batch)
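
Note (not part of the patch): the change relies on Python's two-argument iter(callable, sentinel) form. iter(int, 1) calls int() repeatedly, which always returns 0 and never equals the sentinel 1, so it acts as an unbounded iterator; zip() then terminates only when iterate_batches is exhausted, i.e. after the entire set. A minimal standalone sketch, with a placeholder list standing in for iterate_batches:

    # Hypothetical example, not part of lora.py
    batches = ["b0", "b1", "b2"]  # stands in for iterate_batches(dataset, tokenizer, batch_size)
    num_batches = -1              # -1 means "evaluate on the entire set"

    # iter(int, 1) yields 0 forever, so zip() runs until `batches` is exhausted;
    # note that `it` stays 0 in this mode, which is fine since only the count of
    # consumed batches matters here.
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for it, batch in zip(index_iterator, batches):
        print(it, batch)          # prints: 0 b0, 0 b1, 0 b2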