diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 9a251011..7d1821d4 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -101,34 +101,33 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     """
     Returns the beginning and end index of the first occurrence of small_list in big_list.
     """
-    for i in range(len(big_list) - len(small_list) + 1):
-        for j in range(len(small_list)):
-            if big_list[i + j] != small_list[j]:
-                break
-        else:
-            return i, i + len(small_list)
-    raise RuntimeError("Not found")
+    small_list_length = len(small_list)
+    for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
+        if big_list[ind : ind + small_list_length] == small_list:
+            return ind, ind + small_list_length - 1


 def no_bos(sequence: List, bos: int) -> List:
     return sequence if sequence[0] != bos else sequence[1:]


-def input_and_output_lengths(
+def input_length(
     input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
-) -> Tuple[int, int]:
+) -> int:
     """
     Returns the length of the portion of the encoding of the concatenation of input_text and output_text
-    that corresponds to the input tokens and the length of the portion that corresponds to the output tokens.
+    that corresponds to the input tokens.
     """
     message = [
         {"role": "user", "content": input_text},
         {"role": "assistant", "content": output_text},
     ]
     output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
-    full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
+    full_sequence = tokenizer.apply_chat_template(
+        message, add_generation_prompt=True, tokenize=True
+    )
     output_begin, output_end = contains(output_tokens, full_sequence)
-    return output_begin, len(full_sequence) - len(output_tokens) + 1
+    return output_begin


 def iterate_delineated_batches(
@@ -163,17 +162,10 @@ def iterate_delineated_batches(
     indices = np.random.permutation(len(batch_idx))
     for i in indices:
         prompt_lengths = []
-        completion_lengths = []
         batch = []
         for j in batch_idx[i]:
             prompt, completion = dataset.get_prompt_and_completion(j)
-            prompt_length, completion_length = input_and_output_lengths(
-                prompt, completion, tokenizer
-            )
-
-            prompt_lengths.append(prompt_length)
-            completion_lengths.append(completion_length)
-
+            prompt_lengths.append(input_length(prompt, completion, tokenizer))
         full_sequence = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
         if full_sequence[-1] != tokenizer.eos_token_id:
             full_sequence.append(tokenizer.eos_token_id)
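
For reviewers: below is a minimal, self-contained sketch of the revised contains helper from this diff, runnable outside the repo. It highlights two behavioral changes visible in the + lines: the returned end index is now inclusive (ind + small_list_length - 1, where the removed version returned the exclusive i + len(small_list)), and a failed search now falls through to an implicit None instead of raising RuntimeError("Not found"). The Optional return annotation and the example inputs are additions of this sketch, not part of the diff.

    from typing import List, Optional, Tuple


    def contains(small_list: List, big_list: List) -> Optional[Tuple[int, int]]:
        # Same body as the + lines in the diff; Optional is added here only to
        # make the fall-through-to-None behavior explicit in this sketch.
        small_list_length = len(small_list)
        # Only test candidate offsets whose first element already matches.
        for ind in (i for i, e in enumerate(big_list) if e == small_list[0]):
            if big_list[ind : ind + small_list_length] == small_list:
                return ind, ind + small_list_length - 1


    print(contains([3, 4], [1, 2, 3, 4, 5]))  # (2, 3): end index is inclusive
    print(contains([9], [1, 2, 3]))  # None: no match, and no RuntimeError raised

Note that the sole remaining caller, input_length, still unpacks the result as output_begin, output_end = contains(...), so a None return would surface as a TypeError at the call site rather than the old RuntimeError.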