Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-08-29 01:12:24 +08:00
Replace iterate_input_masked_batches with iterate_delineated_batches, an updated attempt to better sync with iterate_batches logic
This commit is contained in:
parent 604be3cec9
commit 79a042768f
@@ -74,6 +74,9 @@ class CompletionsDataset:
             for d in data
         ]
 
+    def get_prompt_and_completion(self, idx: int):
+        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]
+
     def __getitem__(self, idx: int):
         return self._data[idx]
 
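The new accessor simply reads the prompt and completion fields of a stored record back out by their configured keys. A minimal sketch of that idea using a stripped-down stand-in class; the key names and sample record below are assumptions for illustration, and the real CompletionsDataset builds self._data from `data` in its constructor (omitted here):

# Hypothetical stand-in for illustration only; not the real CompletionsDataset.
class ToyCompletionsDataset:
    def __init__(self, data, prompt_key="prompt", completion_key="completion"):
        self._data = data                    # the real class preprocesses `data` first
        self._prompt_key = prompt_key
        self._completion_key = completion_key

    def get_prompt_and_completion(self, idx: int):
        return self._data[idx][self._prompt_key], self._data[idx][self._completion_key]

prompt, completion = ToyCompletionsDataset(
    [{"prompt": "What is 2 + 2?", "completion": "4"}]
).get_prompt_and_completion(0)
print(prompt, "->", completion)  # What is 2 + 2? -> 4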
@@ -5,13 +5,16 @@ import shutil
 import time
 from dataclasses import dataclass, field
 from pathlib import Path
-from typing import Union
+from typing import List, Tuple
 
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
 from mlx.nn.utils import average_gradients
 from mlx.utils import tree_flatten
+from transformers import PreTrainedTokenizer
+
+from .datasets import CompletionsDataset
 
 
 def grad_checkpoint(layer):
@@ -63,47 +66,6 @@ class TrainingArgs:
     )
 
 
-def iterate_input_masked_batches(
-    input_text, output_text, tokenizer, max_seq_length=2048
-):
-    batch_size = len(input_text)
-
-    input_batch = [tokenizer.encode(record) for record in input_text]
-    output_batch = [tokenizer.encode(record) for record in output_text]
-
-    input_lengths = [len(x) for x in input_batch]
-    output_lengths = [len(x) for x in output_batch]
-
-    full_labels = [input_batch[idx] + output_batch[idx] for idx in range(batch_size)]
-    lengths = [len(x) for x in full_labels]
-
-    if max(lengths) > max_seq_length:
-        print(
-            f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-            f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-            "Consider pre-splitting your data to save memory."
-        )
-    pad_to = 8
-    max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-    max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-    batch_arr = np.zeros((batch_size, max_length_in_batch), np.int32)
-    adjusted_lengths = []
-    for j in range(batch_size):
-        input_length = input_lengths[j]
-        full_ids_end_idx = input_length + min(
-            output_lengths[j], max_length_in_batch - input_length
-        )
-        adjusted_lengths.append(full_ids_end_idx)
-        batch_arr[j, :full_ids_end_idx] = full_labels[j][:full_ids_end_idx]
-
-    batch = mx.array(batch_arr)
-    input_lengths = mx.array(input_lengths)
-    lengths = mx.array(adjusted_lengths)
-
-    return batch[:, :-1], input_lengths, lengths
-
-
 def input_masked_loss(model, inputs, input_lengths, lengths):
     shifted_inputs = inputs[:, :-1]
     shifted_labels = inputs[:, 1:]
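Both the removed function above and the new iterator added below round the batch width up to a multiple of 8 and cap it at max_seq_length. A quick worked example of that arithmetic; the lengths and cap used here are assumed values, not taken from the diff:

# Illustrative values only.
pad_to = 8
max_seq_length = 2048
lengths = [13, 21, 37]  # hypothetical token counts for one batch

max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
print(max_length_in_batch)  # 40: max(lengths) = 37, rounded up to the next multiple of 8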
@@ -135,6 +97,117 @@ def default_loss(model, inputs, targets, lengths):
     return ce, ntoks
 
 
+def contains(small_list: List, big_list: List) -> Tuple[int, int]:
+    """
+    Returns the beginning and end index of the first occurrence of small_list in big_list.
+    """
+    for i in range(len(big_list) - len(small_list) + 1):
+        for j in range(len(small_list)):
+            if big_list[i + j] != small_list[j]:
+                break
+        else:
+            return i, i + len(small_list)
+    raise RuntimeError("Not found")
+
+
+def no_bos(sequence: List, bos: int) -> List:
+    return sequence if sequence[0] != bos else sequence[1:]
+
+
+def input_and_output_lengths(
+    input_text: str, output_text: str, tokenizer: PreTrainedTokenizer
+) -> Tuple[int, int]:
+    """
+    Returns the length of the portion of the encoding of the concatenation of input_text and output_text
+    that corresponds to the input tokens and the length of the portion that corresponds to the output tokens.
+    """
+    message = [
+        {"role": "user", "content": input_text},
+        {"role": "assistant", "content": output_text},
+    ]
+    output_tokens = no_bos(tokenizer.encode(output_text), tokenizer.bos_token_id)
+    full_sequence = tokenizer.apply_chat_template(message, tokenize=True)
+    output_begin, output_end = contains(output_tokens, full_sequence)
+    return output_begin, len(full_sequence) - len(output_tokens) + 1
+
+
+def iterate_delineated_batches(
+    dataset: CompletionsDataset,
+    tokenizer: PreTrainedTokenizer,
+    batch_size: int,
+    max_seq_length: int,
+    train: bool = False,
+):
+    """
+    A version of iterate_batches that works with completion datasets, tracks the boundaries between input/output
+    tokens (using input_and_output_lengths), and returns the lengths of the input tokens as well as the full
+    sequences.
+    """
+    idx = sorted(range(len(dataset)), key=lambda i: len(dataset[i]))
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size}"
+            f" examples but only has {len(dataset)}."
+        )
+
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+    # Make the batches:
+    batch_idx = [
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
+    ]
+    while True:
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
+            prompt_lengths = []
+            completion_lengths = []
+            batch = []
+            for j in batch_idx[i]:
+                prompt, completion = dataset.get_prompt_and_completion(j)
+                prompt_length, completion_length = input_and_output_lengths(
+                    prompt, completion, tokenizer
+                )
+
+                prompt_lengths.append(prompt_length)
+                completion_lengths.append(completion_length)
+
+                full_sequence = tokenizer.encode(dataset[j])
+                if full_sequence[-1] != tokenizer.eos_token_id:
+                    full_sequence.append(tokenizer.eos_token_id)
+                batch.append(full_sequence)
+
+            lengths = [len(x) for x in batch]
+
+            if max(lengths) > max_seq_length:
+                print(
+                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
+                    "Consider pre-splitting your data to save memory."
+                )
+
+            # Pad to the nearest multiple of 8 or the maximum length
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
+
+            for j in range(batch_size // step):
+                truncated_length = min(lengths[j], max_seq_length)
+                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                lengths[j] = (
+                    truncated_length  # Update lengths to match truncated lengths
+                )
+
+            yield mx.array(batch_arr), mx.array(prompt_lengths), mx.array(lengths)
+
+        if not train:
+            break
+
+
 def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
     # Sort by length:
     idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
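The boundary tracking added above hinges on two small helpers: no_bos drops a leading BOS id so the completion tokens can be located verbatim, and contains does a brute-force sublist search over the chat-templated sequence. A toy illustration with made-up token ids; it assumes the contains() and no_bos() definitions from the diff are in scope, and no real tokenizer is involved:

# Made-up token ids for illustration; 1 plays the role of BOS here.
bos_id = 1
completion_tokens = no_bos([1, 7, 8, 9], bos_id)  # -> [7, 8, 9]
full_sequence = [1, 4, 5, 6, 7, 8, 9, 2]          # BOS, prompt tokens, completion tokens, suffix

begin, end = contains(completion_tokens, full_sequence)
print(begin, end)  # 4 7: the completion occupies indices 4..6, so the first 4 tokens are prompt-side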
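The batch construction in the new iterator strides through each batch window of the sorted indices with step equal to the distributed world size (the code's comment notes that each machine should skip N-1 samples). A small sketch of just that slicing, using assumed numbers rather than anything from the diff:

# Hypothetical values: 8 samples, batch_size=4, 2 workers (step=2).
idx = list(range(8))     # indices already sorted by sequence length
batch_size, step = 4, 2

batch_idx = [
    idx[i : i + batch_size : step]
    for i in range(0, len(idx) - batch_size + 1, batch_size)
]
print(batch_idx)  # [[0, 2], [4, 6]]: every step-th index of each batch window of 4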