From 604be3cec9eec48664f1c175e4e3af763c054930 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Fri, 7 Jun 2024 12:35:07 -0400
Subject: [PATCH] Add input_masked loss calculation and batching w/ padding

---
 llms/mlx_lm/tuner/trainer.py | 59 ++++++++++++++++++++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index bf84d066..e8e6ea6a 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -63,6 +63,65 @@ class TrainingArgs:
     )
 
 
+def iterate_input_masked_batches(
+    input_text, output_text, tokenizer, max_seq_length=2048
+):
+    batch_size = len(input_text)
+
+    input_batch = [tokenizer.encode(record) for record in input_text]
+    output_batch = [tokenizer.encode(record) for record in output_text]
+
+    input_lengths = [len(x) for x in input_batch]
+    output_lengths = [len(x) for x in output_batch]
+
+    full_labels = [input_batch[idx] + output_batch[idx] for idx in range(batch_size)]
+    lengths = [len(x) for x in full_labels]
+
+    if max(lengths) > max_seq_length:
+        print(
+            f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+            f"The longest sequence {max(lengths)} will be truncated to {max_seq_length}. "
+            "Consider pre-splitting your data to save memory."
+        )
+    pad_to = 8
+    max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+    max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
+    batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
+    adjusted_lengths = []
+    for j in range(batch_size):
+        input_length = input_lengths[j]
+        full_ids_end_idx = input_length + min(
+            output_lengths[j], max_length_in_batch - input_length
+        )
+        adjusted_lengths.append(full_ids_end_idx)
+        batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]
+
+    batch = mx.array(batch_arr)
+    input_lengths = mx.array(input_lengths)
+    lengths = mx.array(adjusted_lengths)
+
+    return batch, input_lengths, lengths  # input_masked_loss shifts the batch itself
+
+
+def input_masked_loss(model, inputs, input_lengths, lengths):
+    shifted_inputs = inputs[:, :-1]
+    shifted_labels = inputs[:, 1:]
+    logits = model(shifted_inputs)
+    logits = logits.astype(mx.float32)
+
+    mask_width = shifted_inputs.shape[1]
+    token_indices = mx.arange(1, mask_width + 1)[None, :]  # label i is token i + 1
+    mask = mx.logical_and(
+        token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
+    )
+
+    ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
+    ntoks = mask.sum()
+    ce = ce.sum() / ntoks
+    return ce, ntoks
+
+
 def default_loss(model, inputs, targets, lengths):
     logits = model(inputs)
     logits = logits.astype(mx.float32)
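
Not part of the patch: a minimal usage sketch of the two helpers above, assuming
the patch has been applied and mlx_lm is importable. ToyTokenizer and ToyModel are
hypothetical stand-ins for a real tokenizer and model (anything with an encode
method and any callable mapping token ids to logits will do); only the completion
tokens should contribute to the reported loss.

import mlx.core as mx
import mlx.nn as nn

from mlx_lm.tuner.trainer import input_masked_loss, iterate_input_masked_batches


class ToyTokenizer:
    """Hypothetical stand-in for a real tokenizer: one token id per character."""

    def encode(self, text):
        return [ord(c) % 128 for c in text]


class ToyModel(nn.Module):
    """Hypothetical stand-in language model: embedding plus a vocab projection."""

    def __init__(self, vocab_size=128, dims=16):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, dims)
        self.out = nn.Linear(dims, vocab_size)

    def __call__(self, x):
        return self.out(self.embed(x))


prompts = ["What is 2 + 2? ", "Name a color: "]
completions = ["4", "blue"]

# Build one padded batch from paired prompt/completion texts.
batch, input_lengths, lengths = iterate_input_masked_batches(
    prompts, completions, ToyTokenizer(), max_seq_length=32
)

# Cross-entropy averaged over completion tokens only; ntoks counts them.
loss, ntoks = input_masked_loss(ToyModel(), batch, input_lengths, lengths)
print(batch.shape, ntoks.item(), loss.item())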
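
Also not part of the patch: a standalone illustration of the completion mask built
in input_masked_loss. After the one-token shift, the label at position i is token
i + 1 of the padded row, so a prompt of length 3 followed by a two-token completion
(full length 5) should leave exactly two positions unmasked. The lengths below are
made up for illustration.

import mlx.core as mx

seq_len = 8                    # padded batch width
input_lengths = mx.array([3])  # prompt occupies tokens 0..2
lengths = mx.array([5])        # prompt + completion occupy tokens 0..4

mask_width = seq_len - 1                               # width after the shift
token_indices = mx.arange(1, mask_width + 1)[None, :]  # label i is token i + 1
mask = mx.logical_and(
    token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
)

# Tokens 3 and 4 are the completion, so exactly two positions are kept.
print(mask)
print(mask.sum().item())  # 2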