Incorporate use of a response template for completion masking

Follow the example of trl's DataCollatorForCompletionOnlyLM: use a response template to identify where the completion/continuation tokens begin, so that all other tokens can be masked out during loss calculation.
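In essence, the response template (e.g. the tokens that open the assistant turn) is tokenized once, and each encoded example is scanned for that token subsequence; everything up to and including the marker is treated as prompt and excluded from the loss. A minimal, self-contained sketch of the idea (the function name and error-handling choice are illustrative, not part of this change):

    from typing import List

    def response_prefix_length(
        full_sequence: List[int], response_token_ids: List[int]
    ) -> int:
        """Index of the first completion token, i.e. the length of the
        prefix (prompt plus response marker) to exclude from the loss."""
        marker_len = len(response_token_ids)
        for start in range(len(full_sequence) - marker_len + 1):
            # Slide over the sequence until the tokenized template matches.
            if full_sequence[start : start + marker_len] == response_token_ids:
                return start + marker_len
        raise ValueError("response template not found in sequence")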
Authored by Chime Ogbuji on 2024-12-08 11:26:50 -05:00; committed by Awni Hannun
parent cb87f6f22c
commit 95e1f22812


@@ -66,7 +66,7 @@ class TrainingArgs:
     )


-def input_masked_loss(model, inputs, input_lengths, lengths):
+def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
     shifted_inputs = inputs[:, :-1]
     shifted_labels = inputs[:, 1:]
     logits = model(shifted_inputs)
@@ -75,7 +75,8 @@ def input_masked_loss(model, inputs, input_lengths, lengths):
     mask_width = shifted_inputs.shape[1]
     token_indices = mx.arange(mask_width)[None, :]
     mask = mx.logical_and(
-        token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
+        token_indices >= response_prefix_lengths[:, None],
+        token_indices < lengths[:, None],
     )

     ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
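To see what the rewritten mask computes, a toy example with made-up values, mirroring the broadcast above: each row is True exactly on its completion tokens, so prompt and padding positions contribute nothing to the cross-entropy.

    import mlx.core as mx

    response_prefix_lengths = mx.array([3, 5])  # first completion token per row
    lengths = mx.array([6, 7])                  # total (possibly truncated) lengths
    token_indices = mx.arange(8)[None, :]       # mask_width = 8 here

    mask = mx.logical_and(
        token_indices >= response_prefix_lengths[:, None],
        token_indices < lengths[:, None],
    )
    # Row 0: True at positions 3..5; row 1: True at positions 5..6.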
@@ -107,29 +108,6 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     return ind, ind + small_list_length - 1


 def no_bos(sequence: List, bos: int) -> List:
     return sequence if sequence[0] != bos else sequence[1:]


-def input_length(
-    input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
-) -> int:
-    """
-    Returns the length of the portion of the encoding of the concatenation of
-    input_text and output_text that corresponds to the input tokens.
-    """
-    message = [
-        {"role": "user", "content": input_text},
-        {"role": "assistant", "content": output_text},
-    ]
-    output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
-    full_sequence = tokenizer.apply_chat_template(
-        message, add_generation_prompt=True, tokenize=True
-    )
-    output_begin, output_end = contains(output_tokens, full_sequence)
-    return output_begin
-
-
 def iterate_completion_batches(
     dataset: CompletionsDataset,
     tokenizer: PreTrainedTokenizer,
@@ -161,15 +139,23 @@ def iterate_completion_batches(
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            prompt_lengths = []
+            response_prefix_lengths = []
             batch = []
             for j in batch_idx[i]:
-                prompt, completion = dataset.get_prompt_and_completion(j)
-                prompt_lengths.append(input_length(prompt, completion, tokenizer))
-                full_sequence = tokenizer.encode(dataset[j])
+                full_sequence = dataset.get_item(j, tokenize=True)
                 if full_sequence[-1] != tokenizer.eos_token_id:
                     full_sequence.append(tokenizer.eos_token_id)
                 batch.append(full_sequence)
+                if len(dataset.response_token_ids) > 1:
+                    response_marker_begin, response_marker_end = contains(
+                        dataset.response_token_ids, full_sequence
+                    )
+                    response_prefix_lengths.append(response_marker_end + 1)
+                else:
+                    response_marker_begin = full_sequence.index(
+                        dataset.response_token_ids[0]
+                    )
+                    response_prefix_lengths.append(response_marker_begin + 1)

             lengths = [len(x) for x in batch]
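A toy walk-through of the two branches above (token ids are made up): a multi-token response marker is located with contains(), which returns the first and last indices of the subsequence, while a single-token marker only needs list.index(); in both cases the completion starts one past the marker's last token.

    full_sequence = [1, 10, 11, 99, 100, 42, 43, 2]

    # Multi-token marker [99, 100]: contains() would return (3, 4),
    # so the completion starts at index 4 + 1 == 5 (token 42).
    assert full_sequence[3:5] == [99, 100]

    # Single-token marker [99]: full_sequence.index(99) == 3,
    # so the completion starts at index 3 + 1 == 4.
    assert full_sequence.index(99) + 1 == 4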
@@ -186,15 +172,19 @@ def iterate_completion_batches(
             max_length_in_batch = min(max_length_in_batch, max_seq_length)

             batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
             for j in range(batch_size // step):
+                response_prefix_length = response_prefix_lengths[j]
                 truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                batch_arr[j, response_prefix_length:truncated_length] = batch[j][
+                    response_prefix_length:truncated_length
+                ]
                 lengths[j] = (
                     truncated_length  # Update lengths to match truncated lengths
                 )

-            yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
+            yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
+                lengths
+            )

         if not train:
             break
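For orientation, the yielded triple pairs with the loss from the first hunk; a sketch of that pairing (the surrounding training loop, keyword-argument names, and the loss's return value are assumptions, not part of this diff):

    for batch, response_prefix_lengths, lengths in iterate_completion_batches(
        dataset=dataset,
        tokenizer=tokenizer,
        batch_size=4,
        max_seq_length=2048,
        train=True,
    ):
        # Only positions in [response_prefix_length, length) count toward the loss.
        loss = input_masked_loss(model, batch, response_prefix_lengths, lengths)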