1. Add user warning for sequences over 2048 tokens in iterate_batches. (#166)

This commit is contained in:
wyanzhao
2023-12-21 06:29:31 -08:00
committed by GitHub
parent 58f409feb0
commit 18f0a96cee

View File

@@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
for j in range(batch_size)
]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens
if max(lengths) > 2048:
print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)