1. Add user warning for sequences over 2048 tokens in iterate_batches.

This commit is contained in:
wyanzhao
2023-12-20 23:35:16 -08:00
parent 3efb1cc2cc
commit d8a4920e66

View File

@@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
for j in range(batch_size)
]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens
if max(lengths) > 2048:
print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)