mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-26 18:51:18 +08:00
1. Add user warning for sequences over 2048 tokens in iterate_batches. (#166)
This commit is contained in:
parent
43b6522af2
commit
22620de3ee
@@ -210,6 +210,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
 ]
 lengths = [len(x) for x in batch]
 
+# Check if any sequence is longer than 2048 tokens
+if max(lengths) > 2048:
+    print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
+
 # Pad to the max length
 batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
 for j in range(batch_size):
Loading…
Reference in New Issue
Block a user