diff --git a/llms/mlx_lm/tuner/trainer.py b/llms/mlx_lm/tuner/trainer.py
index 99cab169..1b202961 100644
--- a/llms/mlx_lm/tuner/trainer.py
+++ b/llms/mlx_lm/tuner/trainer.py
@@ -66,7 +66,7 @@ class TrainingArgs:
     )
 
 
-def input_masked_loss(model, inputs, input_lengths, lengths):
+def input_masked_loss(model, inputs, response_prefix_lengths, lengths):
     shifted_inputs = inputs[:, :-1]
     shifted_labels = inputs[:, 1:]
     logits = model(shifted_inputs)
@@ -75,7 +75,8 @@ def input_masked_loss(model, inputs, input_lengths, lengths):
     mask_width = shifted_inputs.shape[1]
     token_indices = mx.arange(mask_width)[None, :]
     mask = mx.logical_and(
-        token_indices >= input_lengths[:, None], token_indices < lengths[:, None]
+        token_indices >= response_prefix_lengths[:, None],
+        token_indices < lengths[:, None],
     )
 
     ce = nn.losses.cross_entropy(logits, shifted_labels) * mask
@@ -107,29 +108,6 @@ def contains(small_list: List, big_list: List) -> Tuple[int, int]:
     return ind, ind + small_list_length - 1
 
 
-def no_bos(sequence: List, bos: int) -> List:
-    return sequence if sequence[0] != bos else sequence[1:]
-
-
-def input_length(
-    input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
-) -> int:
-    """
-    Returns the length of the portion of the encoding of the concatenation of input_text and output_text
-    that corresponds to the input tokens.
-    """
-    message = [
-        {"role": "user", "content": input_text},
-        {"role": "assistant", "content": output_text},
-    ]
-    output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
-    full_sequence = tokenizer.apply_chat_template(
-        message, add_generation_prompt=True, tokenize=True
-    )
-    output_begin, output_end = contains(output_tokens, full_sequence)
-    return output_begin
-
-
 def iterate_completion_batches(
     dataset: CompletionsDataset,
     tokenizer: PreTrainedTokenizer,
@@ -161,15 +139,23 @@ def iterate_completion_batches(
     while True:
         indices = np.random.permutation(len(batch_idx))
         for i in indices:
-            prompt_lengths = []
+            response_prefix_lengths = []
             batch = []
             for j in batch_idx[i]:
-                prompt, completion = dataset.get_prompt_and_completion(j)
-                prompt_lengths.append(input_length(prompt, completion, tokenizer))
-                full_sequence = tokenizer.encode(dataset[j])
+                full_sequence = dataset.get_item(j, tokenize=True)
                 if full_sequence[-1] != tokenizer.eos_token_id:
                     full_sequence.append(tokenizer.eos_token_id)
                 batch.append(full_sequence)
+                if len(dataset.response_token_ids) > 1:
+                    response_marker_begin, response_marker_end = contains(
+                        dataset.response_token_ids, full_sequence
+                    )
+                    response_prefix_lengths.append(response_marker_end + 1)
+                else:
+                    response_marker_begin = full_sequence.index(
+                        dataset.response_token_ids[0]
+                    )
+                    response_prefix_lengths.append(response_marker_begin + 1)
 
             lengths = [len(x) for x in batch]
 
@@ -186,15 +172,19 @@ def iterate_completion_batches(
                 max_length_in_batch = min(max_length_in_batch, max_seq_length)
 
             batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-
             for j in range(batch_size // step):
+                response_prefix_length = response_prefix_lengths[j]
                 truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                batch_arr[j, response_prefix_length:truncated_length] = batch[j][
+                    response_prefix_length:truncated_length
+                ]
                 lengths[j] = (
                     truncated_length  # Update lengths to match truncated lengths
                 )
 
-            yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
+            yield mx.array(batch_arr), mx.array(response_prefix_lengths), mx.array(
+                lengths
+            )
 
             if not train:
                 break
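
A minimal, self-contained sketch of the masking behaviour these changes rely on (the token ids, marker, and values below are illustrative stand-ins, not taken from the repository): the response marker is located in the tokenized sequence as in iterate_completion_batches, everything up to and including the marker counts as prefix, and only the remaining completion positions contribute to the cross-entropy in input_masked_loss.

# Illustrative sketch only; assumes mlx is installed and uses made-up token ids.
import mlx.core as mx

full_sequence = [1, 10, 11, 12, 99, 98, 20, 21, 2]  # hypothetical prompt + marker + completion + eos
response_token_ids = [99, 98]                        # hypothetical response marker tokens

# Locate the marker (stand-in for the repo's `contains` helper) and derive the prefix length.
marker_len = len(response_token_ids)
marker_begin = next(
    i for i in range(len(full_sequence) - marker_len + 1)
    if full_sequence[i : i + marker_len] == response_token_ids
)
response_prefix_length = marker_begin + marker_len  # loss starts right after the marker

# Build the same mask shape as input_masked_loss: width is the shifted sequence length.
lengths = mx.array([len(full_sequence)])
response_prefix_lengths = mx.array([response_prefix_length])
token_indices = mx.arange(len(full_sequence) - 1)[None, :]
mask = mx.logical_and(
    token_indices >= response_prefix_lengths[:, None],  # drop prompt/prefix positions
    token_indices < lengths[:, None],                    # drop right padding
)
print(mask)  # True only at the completion positions (indices 6 and 7 here)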