Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 09:21:18 +08:00)
Validation with full data set results in NaN validation score (#879)

* CLI arguments may set num_batches to -1. The CLI arguments allow you to validate with the entire dataset by passing a value of -1, but this quickly results in a division-by-zero `NaN` appearing as the validation score!

* Properly assemble the mini-batches when validating with the entire dataset. Tested locally, a validation run over a novel took about an hour, with a loss of 0.928. Thanks @awni for the correction!

* Set up the pre-commit hooks and run them so that black may format lora.py.
Commit 9717307ff0 (parent 63800c8feb)
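For context, here is a minimal sketch of the failure mode described above (a simplified stand-in, not the actual lora.py code): when num_batches is -1, `zip(range(-1), ...)` yields nothing, so no losses or tokens accumulate and the averaged loss ends up as 0/0, i.e. `NaN`.

```python
import numpy as np

# Hypothetical, simplified version of the evaluation loop before the fix.
num_batches = -1                         # CLI value meaning "use the entire validation set"
batches = [("batch-0",), ("batch-1",)]   # stand-in for iterate_batches(...)

all_losses, ntokens = [], 0
for it, batch in zip(range(num_batches), batches):
    all_losses.append(1.0)  # stand-in for a real per-batch loss
    ntokens += 4            # stand-in for a real token count

# range(-1) is empty, so nothing was accumulated and the average is 0/0.
print(np.sum(all_losses) / ntokens)  # nan (numpy warns instead of raising)
```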
@@ -220,8 +220,12 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
 def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches):
     all_losses = []
     ntokens = 0
+
+    # num_batches can be -1 to indicate the entire set
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
     for it, batch in zip(
-        range(num_batches),
+        index_iterator,
         iterate_batches(dataset, tokenizer, batch_size),
     ):
         losses, toks = loss(model, *batch)
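The fix relies on Python's two-argument `iter(callable, sentinel)` form: `int()` always returns 0 and never the sentinel 1, so `iter(int, 1)` is an unbounded iterator, and `zip` then stops only when the batch source is exhausted. A small illustrative sketch (the generator below is a stand-in, not the real `iterate_batches`):

```python
from itertools import islice

# iter(callable, sentinel) keeps calling `callable` until it returns `sentinel`.
# int() always returns 0, which never equals 1, so this iterator is endless.
endless = iter(int, 1)
print(list(islice(endless, 5)))  # [0, 0, 0, 0, 0]

# With an endless first argument, zip() is bounded only by the batch source,
# so every batch in the dataset is visited.
batches = (f"batch-{i}" for i in range(3))  # stand-in for iterate_batches(...)
for it, batch in zip(iter(int, 1), batches):
    print(it, batch)  # 0 batch-0, 0 batch-1, 0 batch-2
```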