From d8a4920e66c34bd9fa1f9d4c631bb8a64d2de777 Mon Sep 17 00:00:00 2001 From: wyanzhao Date: Wed, 20 Dec 2023 23:35:16 -0800 Subject: [PATCH] 1. Add user warning for sequences over 2048 tokens in iterate_batches. --- lora/lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lora/lora.py b/lora/lora.py index 875b51ef..718a27c9 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False): for j in range(batch_size) ] lengths = [len(x) for x in batch] + + # Check if any sequence is longer than 2048 tokens + if max(lengths) > 2048: + print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.") # Pad to the max length batch_arr = np.zeros((batch_size, max(lengths)), np.int32)