mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-09 02:22:26 +08:00
1. Add user warning for sequences over 2048 tokens in iterate_batches.
This commit is contained in:
@@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
|
|||||||
for j in range(batch_size)
|
for j in range(batch_size)
|
||||||
]
|
]
|
||||||
lengths = [len(x) for x in batch]
|
lengths = [len(x) for x in batch]
|
||||||
|
|
||||||
|
# Check if any sequence is longer than 2048 tokens
|
||||||
|
if max(lengths) > 2048:
|
||||||
|
print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.")
|
||||||
|
|
||||||
# Pad to the max length
|
# Pad to the max length
|
||||||
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
|
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
|
||||||
|
Reference in New Issue
Block a user