Replace iterate_input_masked_batches with iterate_delineated_batches, an updated attempt to better align with the iterate_batches logic

Chime Ogbuji 2024-11-05 15:17:23 -05:00 committed by Awni Hannun
parent 604be3cec9
commit 79a042768f
2 changed files with 118 additions and 42 deletions


@@ -74,6 +74,9 @@ class CompletionsDataset:
            for d in data
        ]

    def get_prompt_and_completion(self, idx: int):
        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]

    def __getitem__(self, idx: int):
        return self._data[idx]
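
For orientation: the new get_prompt_and_completion accessor is what lets the trainer below locate the prompt/completion boundary for each example. A minimal usage sketch (the constructor arguments shown are assumptions, not necessarily the loader's exact API):

data = [{"prompt": "What is 2 + 2?", "completion": "The answer is 4."}]
dataset = CompletionsDataset(data, tokenizer, "prompt", "completion")

prompt, completion = dataset.get_prompt_and_completion(0)
# prompt == "What is 2 + 2?", completion == "The answer is 4."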


@@ -5,13 +5,16 @@ import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
from typing import List, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from transformers import PreTrainedTokenizer

from .datasets import CompletionsDataset


def grad_checkpoint(layer):
@@ -63,47 +66,6 @@ class TrainingArgs:
    )


def iterate_input_masked_batches(
    input_text, output_text, tokenizer, max_seq_length=2048
):
    batch_size = len(input_text)
    input_batch = [tokenizer.encode(record) for record in input_text]
    output_batch = [tokenizer.encode(record) for record in output_text]

    input_lengths = [len(x) for x in input_batch]
    output_lengths = [len(x) for x in output_batch]

    full_labels = [input_batch[idx] + output_batch[idx] for idx in range(batch_size)]
    lengths = [len(x) for x in full_labels]

    if max(lengths) > max_seq_length:
        print(
            f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
            f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
            "Consider pre-splitting your data to save memory."
        )

    pad_to = 8
    max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
    max_length_in_batch = min(max_length_in_batch, max_seq_length)

    batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
    adjusted_lengths = []
    for j in range(batch_size):
        input_length = input_lengths[j]
        full_ids_end_idx = input_length + min(
            output_lengths[j], max_length_in_batch - input_length
        )
        adjusted_lengths.append(full_ids_end_idx)
        batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]

    batch = mx.array(batch_arr)
    input_lengths = mx.array(input_lengths)
    lengths = mx.array(adjusted_lengths)
    return batch[:, :-1], input_lengths, lengths


def input_masked_loss(model, inputs, input_lengths, lengths):
    shifted_inputs = inputs[:, :-1]
    shifted_labels = inputs[:, 1:]
@@ -135,6 +97,117 @@ def default_loss(model, inputs, targets, lengths):
    return ce, ntoks


def contains(small_list: List, big_list: List) -> Tuple[int, int]:
    """
    Returns the beginning and end indices of the first occurrence of small_list
    in big_list.
    """
    for i in range(len(big_list) - len(small_list) + 1):
        for j in range(len(small_list)):
            if big_list[i + j] != small_list[j]:
                break
        else:
            return i, i + len(small_list)
    raise RuntimeError("Not found")
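
contains is a naive subsequence search that returns an end-exclusive span, e.g.:

contains([3, 4], [1, 2, 3, 4, 5])  # -> (2, 4), since big_list[2:4] == [3, 4]
contains([9], [1, 2, 3])           # raises RuntimeError("Not found")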

def no_bos(sequence: List, bos: int) -> List:
    return sequence if sequence[0] != bos else sequence[1:]

def input_and_output_lengths(
    input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
) -> Tuple[int, int]:
    """
    Returns the lengths of the input and output portions of the chat-templated
    encoding of input_text followed by output_text: the number of tokens that
    precede the output, and the number of tokens in the output itself.
    """
    message = [
        {"role": "user", "content": input_text},
        {"role": "assistant", "content": output_text},
    ]
    output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
    full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
    output_begin, output_end = contains(output_tokens, full_sequence)
    return output_begin, output_end - output_begin
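
To make the boundary arithmetic concrete, an illustrative trace with made-up token ids (the template layout here is an assumption; real chat templates differ):

# Suppose apply_chat_template renders the pair to:
#   full_sequence = [bos, U1, U2, U3, A1, A2, eot]   # indices 0..6
# and tokenizer.encode(output_text), with bos stripped, is [A1, A2].
# contains([A1, A2], full_sequence) finds (output_begin, output_end) == (4, 6),
# so the function returns (4, 2): four prompt-side tokens to mask and two
# completion tokens to train on.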


def iterate_delineated_batches(
    dataset: CompletionsDataset,
    tokenizer: PreTrainedTokenizer,
    batch_size: int,
    max_seq_length: int,
    train: bool = False,
):
    """
    A version of iterate_batches that works with completion datasets, tracks the
    boundary between input and output tokens for each example (via
    input_and_output_lengths), and yields the input lengths alongside the full
    sequence lengths.
    """
    # Sort by length so that examples of similar size land in the same batch:
    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size}"
            f" examples but only has {len(dataset)}."
        )
    # If running in distributed mode (N machines) then each one should skip N-1
    # samples
    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    # Make the batches:
    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
    while True:
        indices = np.random.permutation(len(batch_idx))
        for i in indices:
            prompt_lengths = []
            completion_lengths = []
            batch = []
            for j in batch_idx[i]:
                prompt, completion = dataset.get_prompt_and_completion(j)
                prompt_length, completion_length = input_and_output_lengths(
                    prompt, completion, tokenizer
                )
                prompt_lengths.append(prompt_length)
                completion_lengths.append(completion_length)

                full_sequence = tokenizer.encode(dataset[j])
                if full_sequence[-1] != tokenizer.eos_token_id:
                    full_sequence.append(tokenizer.eos_token_id)
                batch.append(full_sequence)
            lengths = [len(x) for x in batch]
            if max(lengths) > max_seq_length:
                print(
                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
                    "Consider pre-splitting your data to save memory."
                )

            # Pad to the nearest multiple of 8 or the maximum length
            pad_to = 8
            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
            max_length_in_batch = min(max_length_in_batch, max_seq_length)

            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            for j in range(batch_size // step):
                truncated_length = min(lengths[j], max_seq_length)
                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
                lengths[j] = truncated_length  # update length to match truncation

            yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)

        if not train:
            break
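
A sketch of how the three yielded arrays might be consumed by a completion-only loss, mirroring the prompt masking of input_masked_loss above (model and the surrounding training-loop scaffolding are assumptions, not part of this commit):

for batch, input_lengths, lengths in iterate_delineated_batches(
    dataset, tokenizer, batch_size=4, max_seq_length=2048, train=True
):
    # Standard next-token shift: position t of the labels holds batch[:, t + 1].
    shifted_inputs = batch[:, :-1]
    shifted_labels = batch[:, 1:]
    logits = model(shifted_inputs)

    # Train only where the predicted token falls inside the completion span
    # [input_length, length), i.e. mask out the prompt portion of each row.
    token_indices = mx.arange(shifted_labels.shape[1])[None, :]
    mask = mx.logical_and(
        token_indices + 1 >= input_lengths[:, None],
        token_indices + 1 < lengths[:, None],
    )
    ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
    loss = ce.sum() / mask.sum()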


def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    # Sort by length:
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))