mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-08-28 00:30:09 +08:00
Incorporate use of response template for completion masking
Follow the example of trl's DataCollatorForCompletionOnlyLM: use a response template to identify where the completion/continuation tokens begin, so that all other tokens can be masked out during the loss calculation.
This commit is contained in:
parent cb87f6f22c
commit 95e1f22812
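For reference, the trl pattern this commit mirrors looks roughly like the sketch below. The model id and response template string are placeholders, not part of the commit; check trl's documentation for the exact current signature.

```python
from transformers import AutoTokenizer
from trl import DataCollatorForCompletionOnlyLM

# trl locates the response template inside each tokenized example and masks
# every token before the completion (label set to -100), so the loss is
# computed only on the continuation tokens.
tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-v0.1")  # placeholder
response_template = "### Answer:"  # placeholder; must match your prompt format
collator = DataCollatorForCompletionOnlyLM(response_template, tokenizer=tokenizer)
```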
```diff
@@ -66,7 +66,7 @@ class TrainingArgs:
     )
 
 
-def input_masked_loss(model, inputs, input_lengths, lengths):
+def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
     shifted_inputs = inputs[:, :-1]
     shifted_labels = inputs[:, 1:]
     logits = model(shifted_inputs)
@@ -75,7 +75,8 @@ def input_masked_loss(model, inputs, input_lengths, lengths):
     mask_width = shifted_inputs.shape[1]
     token_indices = mx.arange(mask_width)[None, :]
     mask = mx.logical_and(
-        token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
+        token_indices >= response_prefix_lengths[:, None],
+        token_indices < lengths[:, None],
     )
 
     ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
```
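To make the masking concrete, here is a minimal sketch of the broadcast above with made-up lengths (not part of the commit):

```python
import mlx.core as mx

# Batch of 2 sequences, 6 shifted positions each; lengths are illustrative.
mask_width = 6
response_prefix_lengths = mx.array([2, 3])  # index of first completion token
lengths = mx.array([5, 6])                  # total (truncated) sequence lengths

token_indices = mx.arange(mask_width)[None, :]          # shape (1, 6)
mask = mx.logical_and(
    token_indices >= response_prefix_lengths[:, None],  # drop prompt/template tokens
    token_indices < lengths[:, None],                   # drop padding tokens
)
print(mask)
# [[False, False, True, True, True, False],
#  [False, False, False, True, True, True]]
```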
```diff
@@ -107,29 +108,6 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     return ind, ind + small_list_length - 1
 
 
-def no_bos(sequence: List, bos: int) -> List:
-    return sequence if sequence[0] != bos else sequence[1:]
-
-
-def input_length(
-    input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
-) -> int:
-    """
-    Returns the length of the portion of the encoding of the concatenation of input_text and output_text
-    that corresponds to the input tokens.
-    """
-    message = [
-        {"role": "user", "content": input_text},
-        {"role": "assistant", "content": output_text},
-    ]
-    output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
-    full_sequence = tokenizer.apply_chat_template(
-        message, add_generation_prompt=True, tokenize=True
-    )
-    output_begin, output_end = contains(output_tokens, full_sequence)
-    return output_begin
-
-
 def iterate_completion_batches(
     dataset: CompletionsDataset,
     tokenizer: PreTrainedTokenizer,
```
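Only the final line of `contains` is visible in this hunk. A minimal compatible implementation, assuming it returns the inclusive start/end indices of the first match, might look like:

```python
from typing import List, Tuple

def contains(small_list: List[int], big_list: List[int]) -> Tuple[int, int]:
    """Return inclusive (start, end) indices of the first occurrence of
    small_list as a contiguous subsequence of big_list."""
    small_list_length = len(small_list)
    for ind in range(len(big_list) - small_list_length + 1):
        if big_list[ind : ind + small_list_length] == small_list:
            return ind, ind + small_list_length - 1
    raise ValueError("small_list not found in big_list")
```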
```diff
@@ -161,15 +139,23 @@ def iterate_completion_batches(
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            prompt_lengths = []
+            response_prefix_lengths = []
             batch = []
             for j in batch_idx[i]:
-                prompt, completion = dataset.get_prompt_and_completion(j)
-                prompt_lengths.append(input_length(prompt, completion, tokenizer))
-                full_sequence = tokenizer.encode(dataset[j])
+                full_sequence = dataset.get_item(j, tokenize=True)
                 if full_sequence[-1] != tokenizer.eos_token_id:
                     full_sequence.append(tokenizer.eos_token_id)
                 batch.append(full_sequence)
+                if len(dataset.response_token_ids) > 1:
+                    response_marker_begin, response_marker_end = contains(
+                        dataset.response_token_ids, full_sequence
+                    )
+                    response_prefix_lengths.append(response_marker_end + 1)
+                else:
+                    response_marker_begin = full_sequence.index(
+                        dataset.response_token_ids[0]
+                    )
+                    response_prefix_lengths.append(response_marker_begin + 1)
 
             lengths = [len(x) for x in batch]
```
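Worked through with hypothetical token ids (the real `response_token_ids` come from tokenizing the dataset's response template), reusing the `contains` sketch above:

```python
# Hypothetical token ids, for illustration only.
response_token_ids = [27, 1841]                   # tokenized response template
full_sequence = [1, 50, 51, 27, 1841, 90, 91, 2]  # <bos> prompt, template, completion, <eos>

if len(response_token_ids) > 1:
    begin, end = contains(response_token_ids, full_sequence)  # -> (3, 4)
    response_prefix_length = end + 1                          # -> 5
else:
    begin = full_sequence.index(response_token_ids[0])
    response_prefix_length = begin + 1

# Positions 5..7 ([90, 91, 2]) are the completion tokens; everything
# earlier (prompt + template) falls outside the loss mask.
```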
```diff
@@ -186,15 +172,19 @@ def iterate_completion_batches(
             max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
             batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
 
             for j in range(batch_size // step):
+                response_prefix_length = response_prefix_lengths[j]
                 truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                batch_arr[j, response_prefix_length:truncated_length] = batch[j][
+                    response_prefix_length:truncated_length
+                ]
                 lengths[j] = (
                     truncated_length  # Update lengths to match truncated lengths
                 )
 
-            yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
+            yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
+                lengths
+            )
 
         if not train:
             break
```