fix encoding with special tokens + chat template (#1189)

Awni Hannun
2025-01-03 10:50:59 -08:00
committed by GitHub
parent 3a58c36109
commit c4833a2f55
13 changed files with 95 additions and 97 deletions


@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            # Encode batch
-            batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
-            for b in batch:
-                if b[-1] != tokenizer.eos_token_id:
-                    b.append(tokenizer.eos_token_id)
+            batch = [dataset[j] for j in batch_idx[i]]
             lengths = [len(x) for x in batch]
             if max(lengths) > max_seq_length:
                 print(
                     f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "