mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-02 22:04:53 +08:00
fix encoding with special tokens + chat template (#1189)
This commit is contained in:
@@ -100,14 +100,8 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
# Encode batch
|
||||
batch = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||
for b in batch:
|
||||
if b[-1] != tokenizer.eos_token_id:
|
||||
b.append(tokenizer.eos_token_id)
|
||||
|
||||
batch = [dataset[j] for j in batch_idx[i]]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
|
Reference in New Issue
Block a user