mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-29 01:12:24 +08:00
Replace iterate_input_masked_batches with iterate_delineated_batches, an updated attempt to better sync with iterate_batches logic
This commit is contained in:
parent
604be3cec9
commit
79a042768f
@ -74,6 +74,9 @@ class CompletionsDataset:
|
||||
for d in data
|
||||
]
|
||||
|
||||
def get_prompt_and_completion(self, idx: int):
|
||||
return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]
|
||||
|
||||
|
@ -5,13 +5,16 @@ import shutil
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Union
|
||||
from typing import List, Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from mlx.nn.utils import average_gradients
|
||||
from mlx.utils import tree_flatten
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
from .datasets import CompletionsDataset
|
||||
|
||||
|
||||
def grad_checkpoint(layer):
|
||||
@ -63,47 +66,6 @@ class TrainingArgs:
|
||||
)
|
||||
|
||||
|
||||
def iterate_input_masked_batches(
|
||||
input_text, output_text, tokenizer, max_seq_length=2048
|
||||
):
|
||||
batch_size = len(input_text)
|
||||
|
||||
input_batch = [tokenizer.encode(record) for record in input_text]
|
||||
output_batch = [tokenizer.encode(record) for record in output_text]
|
||||
|
||||
input_lengths = [len(x) for x in input_batch]
|
||||
output_lengths = [len(x) for x in output_batch]
|
||||
|
||||
full_labels = [input_batch[idx] + output_batch[idx] for idx in range(batch_size)]
|
||||
lengths = [len(x) for x in full_labels]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
|
||||
"Consider pre-splitting your data to save memory."
|
||||
)
|
||||
pad_to = 8
|
||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||
|
||||
batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
|
||||
adjusted_lengths = []
|
||||
for j in range(batch_size):
|
||||
input_length = input_lengths[j]
|
||||
full_ids_end_idx = input_length + min(
|
||||
output_lengths[j], max_length_in_batch - input_length
|
||||
)
|
||||
adjusted_lengths.append(full_ids_end_idx)
|
||||
batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]
|
||||
|
||||
batch = mx.array(batch_arr)
|
||||
input_lengths = mx.array(input_lengths)
|
||||
lengths = mx.array(adjusted_lengths)
|
||||
|
||||
return batch[:, :-1], input_lengths, lengths
|
||||
|
||||
|
||||
def input_masked_loss(model, inputs, input_lengths, lengths):
|
||||
shifted_inputs = inputs[:, :-1]
|
||||
shifted_labels = inputs[:, 1:]
|
||||
@ -135,6 +97,117 @@ def default_loss(model, inputs, targets, lengths):
|
||||
return ce, ntoks
|
||||
|
||||
|
||||
def contains(small_list: List, big_list: List) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns the beginning and end index of the first occurrence of small_list in big_list.
|
||||
"""
|
||||
for i in range(len(big_list) - len(small_list) + 1):
|
||||
for j in range(len(small_list)):
|
||||
if big_list[i + j] != small_list[j]:
|
||||
break
|
||||
else:
|
||||
return i, i + len(small_list)
|
||||
raise RuntimeError("Not found")
|
||||
|
||||
|
||||
def no_bos(sequence: List, bos: int) -> List:
|
||||
return sequence if sequence[0] != bos else sequence[1:]
|
||||
|
||||
|
||||
def input_and_output_lengths(
|
||||
input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
|
||||
) -> Tuple[int, int]:
|
||||
"""
|
||||
Returns the length of the portion of the encoding of the concatenation of input_text and output_text
|
||||
that corresponds to the input tokens and the length of the portion that corresponds to the output tokens.
|
||||
"""
|
||||
message = [
|
||||
{"role": "user", "content": input_text},
|
||||
{"role": "assistant", "content": output_text},
|
||||
]
|
||||
output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
|
||||
full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
|
||||
output_begin, output_end = contains(output_tokens, full_sequence)
|
||||
return output_begin, len(full_sequence) - len(output_tokens) + 1
|
||||
|
||||
|
||||
def iterate_delineated_batches(
|
||||
dataset: CompletionsDataset,
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
batch_size: int,
|
||||
max_seq_length: int,
|
||||
train: bool = False,
|
||||
):
|
||||
"""
|
||||
A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output tokens
|
||||
(using create_delineated_batches), and returns the lengths of input tokens as well as the full sequences.
|
||||
"""
|
||||
idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
|
||||
if len(dataset) < batch_size:
|
||||
raise ValueError(
|
||||
f"Dataset must have at least batch_size={batch_size}"
|
||||
f" examples but only has {len(dataset)}."
|
||||
)
|
||||
|
||||
# If running in distributed mode (N machines) then each one should skip N-1
|
||||
# samples
|
||||
step = mx.distributed.init().size()
|
||||
if batch_size % step != 0:
|
||||
raise ValueError("The batch size must be divisible by the number of workers")
|
||||
# Make the batches:
|
||||
batch_idx = [
|
||||
idx[i : i + batch_size : step]
|
||||
for i in range(0, len(idx) - batch_size + 1, batch_size)
|
||||
]
|
||||
while True:
|
||||
indices = np.random.permutation(len(batch_idx))
|
||||
for i in indices:
|
||||
prompt_lengths = []
|
||||
completion_lengths = []
|
||||
batch = []
|
||||
for j in batch_idx[i]:
|
||||
prompt, completion = dataset.get_prompt_and_completion(j)
|
||||
prompt_length, completion_length = input_and_output_lengths(
|
||||
prompt, prompt, tokenizer
|
||||
)
|
||||
|
||||
prompt_lengths.append(prompt_length)
|
||||
completion_lengths.append(completion_length)
|
||||
|
||||
full_sequence = [tokenizer.encode(dataset[j]) for j in batch_idx[i]]
|
||||
if full_sequence[-1] != tokenizer.eos_token_id:
|
||||
full_sequence.append(tokenizer.eos_token_id)
|
||||
batch.append(full_sequence)
|
||||
|
||||
lengths = [len(x) for x in batch]
|
||||
|
||||
if max(lengths) > max_seq_length:
|
||||
print(
|
||||
f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
|
||||
f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
|
||||
"Consider pre-splitting your data to save memory."
|
||||
)
|
||||
|
||||
# Pad to the nearest multiple of 8 or the maximum length
|
||||
pad_to = 8
|
||||
max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
|
||||
max_length_in_batch = min(max_length_in_batch, max_seq_length)
|
||||
|
||||
batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
|
||||
|
||||
for j in range(batch_size // step):
|
||||
truncated_length = min(lengths[j], max_seq_length)
|
||||
batch_arr[j, :truncated_length] = batch[j][:truncated_length]
|
||||
lengths[j] = (
|
||||
truncated_length # Update lengths to match truncated lengths
|
||||
)
|
||||
|
||||
yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
|
||||
|
||||
if not train:
|
||||
break
|
||||
|
||||
|
||||
def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
# Sort by length:
|
||||
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
|
||||
|
Loading…
Reference in New Issue
Block a user