fix(lora): tokenizer return incompatible mx array (#271)

* fix(lora): tokenizer return incompatible encodeing mx array

* add readme nit

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
Anchen
2024-01-09 19:46:38 -08:00
committed by GitHub
parent 7b258f33ac
commit 7cfda327fd
3 changed files with 5 additions and 32 deletions

View File

@@ -172,10 +172,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
# Collect batches from dataset
for i in range(0, len(indices) - batch_size + 1, batch_size):
# Encode batch
batch = [
tokenizer.encode(dset[indices[i + j]], eos=True)
for j in range(batch_size)
]
batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
lengths = [len(x) for x in batch]
# Check if any sequence is longer than 2048 tokens
@@ -187,6 +184,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
# Pad to the max length
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
for j in range(batch_size):
batch_arr[j, : lengths[j]] = batch[j]
batch = mx.array(batch_arr)