From 22620de3ee17da833a958611a2209fb8f07342e7 Mon Sep 17 00:00:00 2001 From: wyanzhao Date: Thu, 21 Dec 2023 06:29:31 -0800 Subject: [PATCH] 1. Add user warning for sequences over 2048 tokens in iterate_batches. (#166) --- lora/lora.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/lora/lora.py b/lora/lora.py index 875b51ef..718a27c9 100644 --- a/lora/lora.py +++ b/lora/lora.py @@ -209,6 +209,10 @@ def iterate_batches(dset, tokenizer, batch_size, train=False): for j in range(batch_size) ] lengths = [len(x) for x in batch] + + # Check if any sequence is longer than 2048 tokens + if max(lengths) > 2048: + print("Warning: Some sequences are longer than 2048 tokens. Consider pre-splitting your data to save memory.") # Pad to the max length batch_arr = np.zeros((batch_size, max(lengths)), np.int32)