diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 3c255703..fcc3e1d0 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -66,19 +66,25 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False)
             ]
             lengths = [len(x) for x in batch]
 
-            # Check if any sequence is longer than max_seq_length
             if max(lengths) > max_seq_length:
                 print(
-                    "[WARNING] Some sequences are longer than 2048 tokens. "
+                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
                     "Consider pre-splitting your data to save memory."
                 )
 
             # Pad to the max length
-            batch_arr = np.zeros((batch_size, max(lengths)), np.int32)
+            max_length_in_batch = min(max(lengths), max_seq_length)
+            batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
 
             for j in range(batch_size):
-                batch_arr[j, : lengths[j]] = batch[j]
+                truncated_length = min(lengths[j], max_seq_length)
+                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                lengths[
+                    j
+                ] = truncated_length  # Update lengths to match truncated lengths
             batch = mx.array(batch_arr)
+
             yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
 
         if not train:
@@ -175,6 +181,7 @@ def train(
                 tokenizer=tokenizer,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
+                max_seq_length=args.max_seq_length,
             )
             print(
                 f"Iter {it + 1}: "
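
For reference, here is a standalone sketch of the padding-and-truncation behaviour the first hunk introduces, using plain NumPy and made-up token lists. The names `batch`, `lengths`, and `max_seq_length` mirror the patched `iterate_batches`, but this is only an illustration of the logic, not the generator itself, and the example values are assumptions.

```python
import numpy as np

# Hypothetical pre-tokenized batch: sequences of 3, 6, and 9 tokens.
batch = [
    [1, 2, 3],
    [1, 2, 3, 4, 5, 6],
    [1, 2, 3, 4, 5, 6, 7, 8, 9],
]
lengths = [len(x) for x in batch]
max_seq_length = 6  # assumed cap for the example

# Pad only up to min(longest, max_seq_length) and truncate anything longer,
# mirroring the logic added in the diff above.
max_length_in_batch = min(max(lengths), max_seq_length)
batch_arr = np.zeros((len(batch), max_length_in_batch), np.int32)
for j in range(len(batch)):
    truncated_length = min(lengths[j], max_seq_length)
    batch_arr[j, :truncated_length] = batch[j][:truncated_length]
    lengths[j] = truncated_length  # lengths now reflect the truncation

print(batch_arr)
# [[1 2 3 0 0 0]
#  [1 2 3 4 5 6]
#  [1 2 3 4 5 6]]
print(lengths)  # [3, 6, 6]
```

Before this change, the batch was padded out to the full length of the longest sequence even when it exceeded `max_seq_length`, which is the memory cost the warning message refers to; after it, both the array width and the reported lengths are capped at `max_seq_length`.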