mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-10-23 22:18:06 +08:00
fix(lora): tokenizer returns incompatible mx array (#271)
* fix(lora): tokenizer returns incompatible encoding mx array * add readme nit --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
@@ -172,10 +172,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
|
||||
# Collect batches from dataset
|
||||
for i in range(0, len(indices) - batch_size + 1, batch_size):
|
||||
# Encode batch
|
||||
batch = [
|
||||
tokenizer.encode(dset[indices[i + j]], eos=True)
|
||||
for j in range(batch_size)
|
||||
]
|
||||
batch = [tokenizer.encode(dset[indices[i + j]]) for j in range(batch_size)]
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
# Check if any sequence is longer than 2048 tokens
|
||||
@@ -187,6 +184,7 @@ def iterate_batches(dset, tokenizer, batch_size, train=False):
|
||||
|
||||
# Pad to the max length
|
||||
batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
|
||||
|
||||
for j in range(batch_size):
|
||||
batch_arr[j, : lengths[j]] = batch[j]
|
||||
batch = mx.array(batch_arr)
|
||||
|
Reference in New Issue
Block a user