Update sublist search and calculation of input id length

Chime Ogbuji 2024-11-06 13:57:59 -05:00 committed by Awni Hannun
parent 30fd5af843
commit 02abeeade4

@@ -101,34 +101,33 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     """
     Returns the beginning and end index of the first occurrence of small_list in big_list.
     """
-    for i in range(len(big_list) - len(small_list) + 1):
-        for j in range(len(small_list)):
-            if big_list[i + j] != small_list[j]:
-                break
-        else:
-            return i, i + len(small_list)
-    raise RuntimeError("Not found")
+    small_list_length = len(small_list)
+    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
+        if big_list[ind : ind + small_list_length] == small_list:
+            return ind, ind + small_list_length - 1


 def no_bos(sequence: List, bos: int) -> List:
     return sequence if sequence[0] != bos else sequence[1:]


-def input_and_output_lengths(
+def input_length(
     input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
-) -> Tuple[int, int]:
+) -> int:
     """
     Returns the length of the portion of the encoding of the concatenation of input_text and output_text
-    that corresponds to the input tokens and the length of the portion that corresponds to the output tokens.
+    that corresponds to the input tokens.
     """
     message = [
         {"role": "user", "content": input_text},
         {"role": "assistant", "content": output_text},
     ]
     output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
-    full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
+    full_sequence = tokenizer.apply_chat_template(
+        message, add_generation_prompt=True, tokenize=True
+    )
     output_begin, output_end = contains(output_tokens, full_sequence)
-    return output_begin, len(full_sequence) - len(output_tokens) + 1
+    return output_begin


 def iterate_delineated_batches(
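For reference, here is a minimal, stand-alone sketch of the updated contains on plain Python lists; the token id values below are made up for illustration. As written in the new implementation, the returned end index is inclusive (ind + small_list_length - 1), whereas the previous version returned an exclusive end (i + len(small_list)).

from typing import List, Tuple

def contains(small_list: List, big_list: List) -> Tuple[int, int]:
    # Same body as the updated implementation above.
    small_list_length = len(small_list)
    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
        if big_list[ind : ind + small_list_length] == small_list:
            return ind, ind + small_list_length - 1
    # Note: unlike the old version, a miss falls through and returns None
    # instead of raising RuntimeError.

full_sequence = [1, 733, 16289, 28793, 12, 7, 42, 2]   # illustrative ids only
output_tokens = [12, 7, 42]

begin, end = contains(output_tokens, full_sequence)
print(begin, end)  # 4 6 -- end index is inclusive
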
@@ -163,17 +162,10 @@ def iterate_delineated_batches(
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             prompt_lengths = []
-            completion_lengths = []
             batch = []
             for j in batch_idx[i]:
                 prompt, completion = dataset.get_prompt_and_completion(j)
-                prompt_length, completion_length = input_and_output_lengths(
-                    prompt, completion, tokenizer
-                )
-                prompt_lengths.append(prompt_length)
-                completion_lengths.append(completion_length)
+                prompt_lengths.append(input_length(prompt, completion, tokenizer))
             full_sequence = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
             if full_sequence[-1] != tokenizer.eos_token_id:
                 full_sequence.append(tokenizer.eos_token_id)
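
To make the new call site concrete, below is a hedged, stand-alone sketch of computing a single prompt length with input_length, mirroring the prompt_lengths.append(...) line above. The helper definitions are copied from the diff for self-containment; the model id is only an example, and the sketch assumes the tokenizer ships a chat template and that the completion re-tokenizes to the same ids inside the templated sequence (if it does not, contains finds no match).

from typing import List, Tuple

from transformers import AutoTokenizer, PreTrainedTokenizer

def no_bos(sequence: List, bos: int) -> List:
    return sequence if sequence[0] != bos else sequence[1:]

def contains(small_list: List, big_list: List) -> Tuple[int, int]:
    small_list_length = len(small_list)
    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
        if big_list[ind : ind + small_list_length] == small_list:
            return ind, ind + small_list_length - 1

def input_length(
    input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
) -> int:
    message = [
        {"role": "user", "content": input_text},
        {"role": "assistant", "content": output_text},
    ]
    output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
    full_sequence = tokenizer.apply_chat_template(
        message, add_generation_prompt=True, tokenize=True
    )
    output_begin, output_end = contains(output_tokens, full_sequence)
    return output_begin

# Example model id only -- substitute any local tokenizer that has a chat template.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

prompt = "What is the capital of France?"
completion = "The capital of France is Paris."

# One value per example, as appended to prompt_lengths in the batching loop above.
print(input_length(prompt, completion, tokenizer))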