1. Add user warning for sequences over 2048 tokens in iterate_batches. (#166)

This commit is contained in:
wyanzhao 2023-12-21 06:29:31 -08:00 committed by GitHub
parent 43b6522af2
commit 22620de3ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
for j in range(batch_size)
]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens
if max(lengths) > 2048:
print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)