Update sublist search and calculation of input id length
parent 30fd5af843
commit 02abeeade4
@@ -101,34 +101,33 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     """
     Returns the beginning and end index of the first occurrence of small_list in big_list.
     """
-    for i in range(len(big_list) - len(small_list) + 1):
-        for j in range(len(small_list)):
-            if big_list[i + j] != small_list[j]:
-                break
-        else:
-            return i, i + len(small_list)
-    raise RuntimeError("Not found")
+    small_list_length = len(small_list)
+    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
+        if big_list[ind : ind + small_list_length] == small_list:
+            return ind, ind + small_list_length - 1


 def no_bos(sequence: List, bos: int) -> List:
     return sequence if sequence[0] != bos else sequence[1:]


-def input_and_output_lengths(
+def input_length(
     input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
-) -> Tuple[int, int]:
+) -> int:
     """
     Returns the length of the portion of the encoding of the concatenation of input_text and output_text
-    that corresponds to the input tokens and the length of the portion that corresponds to the output tokens.
+    that corresponds to the input tokens.
     """
     message = [
         {"role": "user", "content": input_text},
         {"role": "assistant", "content": output_text},
     ]
     output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
-    full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
+    full_sequence = tokenizer.apply_chat_template(
+        message, add_generation_prompt=True, tokenize=True
+    )
     output_begin, output_end = contains(output_tokens, full_sequence)
-    return output_begin, len(full_sequence) - len(output_tokens) + 1
+    return output_begin


 def iterate_delineated_batches(
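
For reference, a minimal runnable sketch (not part of this commit) of what the updated sublist search returns; the token ids are made-up values and the function body is copied from the hunk above:

from typing import List, Tuple

def contains(small_list: List, big_list: List) -> Tuple[int, int]:
    # Updated search: only try start positions whose element matches the first
    # element of small_list, then compare the candidate slice in one step.
    small_list_length = len(small_list)
    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
        if big_list[ind : ind + small_list_length] == small_list:
            return ind, ind + small_list_length - 1

big = [5, 9, 2, 7, 2, 7, 3]   # made-up token ids
sub = [2, 7, 3]
begin, end = contains(sub, big)
# begin == 4, end == 6: the end index is now inclusive, whereas the removed
# implementation returned the exclusive end (i + len(small_list), i.e. 7 here).
# A sublist that never occurs now falls through and returns None, since the
# RuntimeError fallback is dropped in this hunk.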
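
A hedged usage sketch for the renamed input_length(): it applies the tokenizer's chat template to the (prompt, completion) pair, locates the completion tokens inside the templated sequence with contains(), and returns the index at which the completion begins, i.e. the number of input tokens. The model name below is only an example, and input_length, no_bos, and contains are assumed to be in scope as defined above:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")

prompt = "What is the capital of France?"
completion = "The capital of France is Paris."

# Number of leading tokens of the chat-templated sequence that belong to the
# prompt; a trainer can use this to exclude prompt tokens from the loss.
n_prompt_tokens = input_length(prompt, completion, tokenizer)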
@@ -163,17 +162,10 @@ def iterate_delineated_batches(
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
             prompt_lengths = []
-            completion_lengths = []
             batch = []
             for j in batch_idx[i]:
                 prompt, completion = dataset.get_prompt_and_completion(j)
-                prompt_length, completion_length = input_and_output_lengths(
-                    prompt, completion, tokenizer
-                )
-
-                prompt_lengths.append(prompt_length)
-                completion_lengths.append(completion_length)
-
+                prompt_lengths.append(input_length(prompt, completion, tokenizer))
             full_sequence = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
             if full_sequence[-1] != tokenizer.eos_token_id:
                 full_sequence.append(tokenizer.eos_token_id)
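
The hunk above drops the separate completion-length bookkeeping and keeps only the per-example prompt lengths. As a purely hypothetical illustration (not the repo's actual trainer code) of how such prompt lengths are commonly turned into a completion-only loss mask, with made-up names, shapes, and padding scheme:

import numpy as np

def completion_mask(prompt_lengths, sequence_lengths, max_len):
    # mask[b, t] == 1 for completion tokens, 0 for prompt tokens and padding
    positions = np.arange(max_len)[None, :]
    starts = np.array(prompt_lengths)[:, None]
    ends = np.array(sequence_lengths)[:, None]
    return ((positions >= starts) & (positions < ends)).astype(np.float32)

mask = completion_mask(prompt_lengths=[5, 3], sequence_lengths=[9, 7], max_len=10)
# Row 0 marks positions 5..8, row 1 marks positions 3..6.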