mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 02:33:23 +08:00
1. Add user warning for sequences over 2048 tokens in iterate_batches. (#166)
This commit is contained in:
parent
43b6522af2
commit
22620de3ee
@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
|
||||
for j in range(batch_size)
|
||||
]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
# Check if any sequence is longer than 2048 tokens
|
||||
if max(lengths) > 2048:
|
||||
print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
|
||||
|
||||
# Pad to the max length
|
||||
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
|
||||
|
Loading…
Reference in New Issue
Block a user