Add input_masked loss calculation and batching w/ padding

This commit is contained in:
Chime Ogbuji 2024-06-07 12:35:07 -04:00 committed by Awni Hannun
parent f58c7de901
commit 604be3cec9

View File

@ -63,6 +63,65 @@ class TrainingArgs:
)
def iterate_input_masked_batches(
input_text, output_text, tokenizer, max_seq_length=2048
):
batch_size = len(input_text)
input_batch = [tokenizer.encode(record) for record in input_text]
output_batch = [tokenizer.encode(record) for record in output_text]
input_lengths = [len(x) for x in input_batch]
output_lengths = [len(x) for x in output_batch]
full_labels = [input_batch[idx] + output_batch[idx] for idx in range(batch_size)]
lengths = [len(x) for x in full_labels]
if max(lengths) > max_seq_length:
print(
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
"Consider pre-splitting your data to save memory."
)
pad_to = 8
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
adjusted_lengths = []
for j in range(batch_size):
input_length = input_lengths[j]
full_ids_end_idx = input_length + min(
output_lengths[j], max_length_in_batch - input_length
)
adjusted_lengths.append(full_ids_end_idx)
batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]
batch = mx.array(batch_arr)
input_lengths = mx.array(input_lengths)
lengths = mx.array(adjusted_lengths)
return batch[:, :-1], input_lengths, lengths
def input_masked_loss(model, inputs, input_lengths, lengths):
shifted_inputs = inputs[:, :-1]
shifted_labels = inputs[:, 1:]
logits = model(shifted_inputs)
logits = logits.astype(mx.float32)
mask_width = shifted_inputs.shape[1]
token_indices = mx.arange(mask_width)[None, :]
mask = mx.logical_and(
token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
)
ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
ntoks = mask.sum()
ce = ce.sum() / ntoks
return ce, ntoks
def default_loss(model, inputs, targets, lengths):
logits = model(inputs)
logits = logits.astype(mx.float32)