diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 377e7cae..42d19a09 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -74,6 +74,9 @@ class CompletionsDataset:
             for d in data
         ]
 
+    def get_prompt_and_completion(self, idx: int):
+        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
+
     def __getitem__(self, idx: int):
         return self._data[idx]
 
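The new accessor exposes the raw prompt/completion pair for one record, which the trainer changes below use to locate the completion inside the full encoded sequence. A minimal sketch of its use, assuming `dataset` is a CompletionsDataset whose records carry "prompt"/"completion" keys (the data shown is hypothetical):

    # Hedged sketch: assumes `dataset` is a CompletionsDataset built from records
    # like {"prompt": "What is 2 + 2?", "completion": "4"}.
    prompt, completion = dataset.get_prompt_and_completion(0)
    print(prompt, "->", completion)  # e.g. "What is 2 + 2?" -> "4"
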
+ """ + for i in range(len(big_list) - len(small_list) + 1): + for j in range(len(small_list)): + if big_list[i + j] != small_list[j]: + break + else: + return i, i + len(small_list) + raise RuntimeError("Not found") + + +def no_bos(sequence: List, bos: int) -> List: + return sequence if sequence[0] != bos else sequence[1:] + + +def input_and_output_lengths( + input_text: str, output_text: str, tokenizer: PreTrainedTokenizer +) -> Tuple[int, int]: + """ + Returns the length of the portion of the encoding of the concatenation of input_text and output_text + that corresponds to the input tokens and the length of the portion that corresponds to the output tokens. + """ + message = [ + {"role": "user", "content": input_text}, + {"role": "assistant", "content": output_text}, + ] + output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id) + full_sequence = tokenizer.apply_chat_template(message, tokenize=True) + output_begin, output_end = contains(output_tokens, full_sequence) + return output_begin, len(full_sequence) - len(output_tokens) + 1 + + +def iterate_delineated_batches( + dataset: CompletionsDataset, + tokenizer: PreTrainedTokenizer, + batch_size: int, + max_seq_length: int, + train: bool = False, +): + """ + A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens + (using create_delineated_batches), and returns the lengths of input tokens as well as the full sequences. + """ + idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i])) + if len(dataset) < batch_size: + raise ValueError( + f"Dataset must have at least batch_size={batch_size}" + f" examples but only has {len(dataset)}." + ) + + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + # Make the batches: + batch_idx = [ + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) + ] + while True: + indices = np.random.permutation(len(batch_idx)) + for i in indices: + prompt_lengths = [] + completion_lengths = [] + batch = [] + for j in batch_idx[i]: + prompt, completion = dataset.get_prompt_and_completion(j) + prompt_length, completion_length = input_and_output_lengths( + prompt, prompt, tokenizer + ) + + prompt_lengths.append(prompt_length) + completion_lengths.append(completion_length) + + full_sequence = [tokenizer.encode(dataset[j]) for j in batch_idx[i]] + if full_sequence[-1] != tokenizer.eos_token_id: + full_sequence.append(tokenizer.eos_token_id) + batch.append(full_sequence) + + lengths = [len(x) for x in batch] + + if max(lengths) > max_seq_length: + print( + f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " + f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " + "Consider pre-splitting your data to save memory." 
+
+
+def iterate_delineated_batches(
+    dataset: CompletionsDataset,
+    tokenizer: PreTrainedTokenizer,
+    batch_size: int,
+    max_seq_length: int,
+    train: bool = False,
+):
+    """
+    A version of iterate_batches that works with completion datasets, tracks
+    the boundaries between input/output tokens (using input_and_output_lengths),
+    and yields the lengths of the input tokens as well as of the full sequences.
+    """
+    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size}"
+            f" examples but only has {len(dataset)}."
+        )
+
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
+    # Make the batches:
+    batch_idx = [
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
+    ]
+    while True:
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
+            prompt_lengths = []
+            completion_lengths = []
+            batch = []
+            for j in batch_idx[i]:
+                prompt, completion = dataset.get_prompt_and_completion(j)
+                prompt_length, completion_length = input_and_output_lengths(
+                    prompt, completion, tokenizer
+                )
+                prompt_lengths.append(prompt_length)
+                completion_lengths.append(completion_length)
+
+                full_sequence = tokenizer.encode(dataset[j])
+                if full_sequence[-1] != tokenizer.eos_token_id:
+                    full_sequence.append(tokenizer.eos_token_id)
+                batch.append(full_sequence)
+
+            lengths = [len(x) for x in batch]
+            if max(lengths) > max_seq_length:
+                print(
+                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
+                    "Consider pre-splitting your data to save memory."
+                )
+
+            # Pad to the nearest multiple of 8 or the maximum length
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
+            for j in range(batch_size // step):
+                # Truncate over-long sequences and record the truncated lengths
+                truncated_length = min(lengths[j], max_seq_length)
+                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                lengths[j] = truncated_length
+
+            yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
+
+        if not train:
+            break
+
+
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
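For reference, a minimal end-to-end sketch of the new iterator. The model name and data are placeholders, and the CompletionsDataset constructor arguments are assumed from the keys this diff accesses; the yielded triple matches the signature of input_masked_loss:

    # Hypothetical usage sketch; model choice, data, and the CompletionsDataset
    # constructor call are assumptions, not part of this diff.
    from transformers import AutoTokenizer
    from mlx_lm.tuner.datasets import CompletionsDataset
    from mlx_lm.tuner.trainer import iterate_delineated_batches

    tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
    data = [{"prompt": f"What is {n} + {n}?", "completion": str(2 * n)} for n in range(8)]
    dataset = CompletionsDataset(data, tokenizer, "prompt", "completion")

    for inputs, input_lengths, lengths in iterate_delineated_batches(
        dataset, tokenizer, batch_size=4, max_seq_length=512
    ):
        # inputs: (batch, padded_len) int32 token ids
        # input_lengths: token count of each prompt (where the completion begins)
        # lengths: total, possibly truncated, sequence lengths
        break

With train=False (the default) the generator yields each batch once and stops, which is why the single pass above terminates.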