mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Validation with full data set, results in NaN validation score (#879)
* CLI arguments may set num_batches to -1 The CLI arguments allow you to validate with the entire dataset by passing a negative one value, but this quickly results in a division by zero `NaN` to appear as the validation score! * Must properly assemble the mini batches when validating with entire dataset. Tested locally, a validation of a novel took about an hour, with a loss of 0.928. Thanks @awni for the correction! * Set up the pre-commit hooks and run them so that black may format lora.py.
This commit is contained in:
parent
63800c8feb
commit
9717307ff0
@ -220,8 +220,12 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
|
||||
def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
|
||||
all_losses = []
|
||||
ntokens = 0
|
||||
|
||||
# num_batches can be -1 to indicate the entire set
|
||||
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
|
||||
|
||||
for it, batch in zip(
|
||||
range(num_batches),
|
||||
index_iterator,
|
||||
iterate_batches(dataset, tokenizer, batch_size),
|
||||
):
|
||||
losses, toks = loss(model, *batch)
|
||||
|
Loading…
Reference in New Issue
Block a user