From f51e98fcf1147bac049dfd18da699ea2313c7ad9 Mon Sep 17 00:00:00 2001
From: Anchen
Date: Fri, 26 Jan 2024 07:38:04 +1100
Subject: [PATCH] chore(mlx-lm): truncate input sequences to max_seq_length in
 lora iterate_batches (#373)
* chore(mlx-lm): pass max_seq_length to evaluate in the training loop
* chore: make sure batch sequences do not exceed max_seq_length
* chore: update comment
* chore: add a warning before truncating input
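
In effect, each encoded sequence is now clipped to max_seq_length before
padding. A minimal standalone sketch of that behavior (the helper name
pad_and_truncate and the toy token lists below are illustrative only, not
part of this patch):

    import numpy as np

    def pad_and_truncate(batch, max_seq_length):
        # Clip every sequence at max_seq_length, then zero-pad each row to
        # the longest (possibly clipped) length in the batch.
        lengths = [min(len(x), max_seq_length) for x in batch]
        batch_arr = np.zeros((len(batch), max(lengths)), np.int32)
        for j, seq in enumerate(batch):
            batch_arr[j, : lengths[j]] = seq[: lengths[j]]
        return batch_arr, lengths

    # Toy check: the second sequence exceeds max_seq_length=4 and is clipped.
    arr, lengths = pad_and_truncate([[1, 2, 3], [4, 5, 6, 7, 8, 9]], 4)
    print(arr)      # [[1 2 3 0]
                    #  [4 5 6 7]]
    print(lengths)  # [3, 4]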
---
 llms/mlx_lm/tuner/trainer.py | 14 ++++++++++----
 1 file changed, 10 insertions(+), 4 deletions(-)
diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 3c255703..fcc3e1d0 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -66,19 +66,24 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             ]
             lengths = [len(x) for x in batch]
 
-            # Check if any sequence is longer than max_seq_length
             if max(lengths) > max_seq_length:
                 print(
-                    "[WARNING] Some sequences are longer than 2048 tokens. "
+                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+                    f"The longest sequence ({max(lengths)} tokens) will be truncated to {max_seq_length}. "
                     "Consider pre-splitting your data to save memory."
                 )
 
             # Pad to the max length
-            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+            max_length_in_batch = min(max(lengths), max_seq_length)
+            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
 
             for j in range(batch_size):
-                batch_arr[j, : lengths[j]] = batch[j]
+                truncated_length = min(lengths[j], max_seq_length)
+                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                # Keep lengths in sync with the truncated sequences
+                lengths[j] = truncated_length
             batch = mx.array(batch_arr)
+
             yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
 
         if not train:
@@ -175,6 +180,7 @@ def train(
                 tokenizer=tokenizer,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
+                max_seq_length=args.max_seq_length,
             )
             print(
                 f"Iter {it + 1}: "