Generalize HF datasets to a collection of HF datasets via datasets, add support for custom chat HF datasets (#1088), and fix #1087

This commit is contained in:
Chime Ogbuji 2024-11-03 19:11:54 -05:00
parent 331148d8ec
commit 9df7bbbe3a

View File

@ -1,6 +1,6 @@
import json import json
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
@ -29,12 +29,18 @@ class ChatDataset(Dataset):
https://platform.openai.com/docs/guides/fine-tuning/example-format https://platform.openai.com/docs/guides/fine-tuning/example-format
""" """
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer): def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
chat_key: str = "messages",
):
super().__init__(data) super().__init__(data)
self._tokenizer = tokenizer self._tokenizer = tokenizer
self._chat_key = chat_key
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
messages = self._data[idx]["messages"] messages = self._data[idx][self._chat_key]
text = self._tokenizer.apply_chat_template( text = self._tokenizer.apply_chat_template(
messages, messages,
tools=self._data[idx].get("tools", None), tools=self._data[idx].get("tools", None),
@ -76,6 +82,29 @@ class CompletionsDataset(Dataset):
return text return text
class CompletionsDatasetCollection:
    """Present several datasets as one concatenated, indexable dataset.

    Indices run through the member datasets in the order they were given:
    the first ``len(data[0])`` indices address the first member, the next
    ``len(data[1])`` indices address the second, and so on. Empty members
    contribute no indices.
    """

    # The annotation is quoted so it is not evaluated when the class is created.
    def __init__(self, data: "List[Union[ChatDataset, CompletionsDataset]]"):
        # Stored as given (not copied); this class only calls ``len`` and
        # integer indexing on each member.
        self.collection = data

    def __getitem__(self, idx: int):
        """Return the item at position ``idx`` of the concatenation.

        Negative indices count from the end, as with built-in sequences.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        curr_idx = idx
        if curr_idx < 0:
            curr_idx += len(self)
            if curr_idx < 0:
                raise IndexError(idx)
        # Walk the members in order, shifting the index past each member
        # that is too short to contain it.
        for item in self.collection:
            size = len(item)
            if curr_idx < size:
                return item[curr_idx]
            curr_idx -= size
        raise IndexError(idx)

    def __len__(self):
        """Return the total number of items across all member datasets."""
        return sum(map(len, self.collection))
def create_dataset(data, tokenizer: PreTrainedTokenizer = None): def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
sample = data[0] sample = data[0]
@ -127,14 +156,14 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
import datasets import datasets
hf_args = args.hf_dataset def create_hf_dataset(
dataset_name = hf_args["name"] dataset_name: Union[None, str],
print(f"Loading Hugging Face dataset {dataset_name}.") text_feature: Union[None, str],
text_feature = hf_args.get("text_feature") prompt_feature: Union[None, str],
prompt_feature = hf_args.get("prompt_feature") completion_feature: Union[None, str],
completion_feature = hf_args.get("completion_feature") chat_feature: Union[None, str],
split: str = None,
def create_hf_dataset(split: str = None): ):
ds = datasets.load_dataset( ds = datasets.load_dataset(
dataset_name, dataset_name,
split=split, split=split,
@ -142,25 +171,93 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
) )
if prompt_feature and completion_feature: if prompt_feature and completion_feature:
return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
elif chat_feature:
return ChatDataset(ds, tokenizer, chat_key=chat_feature)
elif text_feature: elif text_feature:
return Dataset(train_ds, text_key=text_feature) return Dataset(ds, text_key=text_feature)
else: else:
raise ValueError( raise ValueError(
"Specify either a prompt and completion feature or a text " "Specify either a prompt and completion feature or a text "
"feature for the Hugging Face dataset." "feature for the Hugging Face dataset."
) )
if args.train: def get_hf_custom_features(hf_args):
return (
hf_args.get("text_feature"),
hf_args.get("prompt_feature"),
hf_args.get("completion_feature"),
hf_args.get("chat_feature"),
)
def get_train_and_valid_splits(hf_args, ds_name):
train_split = hf_args.get("train_split", "train[:80%]") train_split = hf_args.get("train_split", "train[:80%]")
valid_split = hf_args.get("valid_split", "train[-10%:]") valid_split = hf_args.get("valid_split", "train[-10%:]")
train = create_hf_dataset(split=train_split) text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args)
valid = create_hf_dataset(split=valid_split) train = create_hf_dataset(
ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split
)
valid = create_hf_dataset(
ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split
)
return train, valid
if args.datasets:
dataset_collection = args.hf_datasets
train_collection = []
valid_collection = []
test_collection = []
for ds in dataset_collection:
hf_args = ds["hf_dataset"]
dataset_name = hf_args["name"]
print(f"Loading Hugging Face dataset {dataset_name}.")
text_feature, prompt_feature, completion_feature, chat_f = (
get_hf_custom_features(hf_args)
)
if args.train:
train, valid = get_train_and_valid_splits(hf_args, dataset_name)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(
dataset_name,
text_feature,
prompt_feature,
completion_feature,
chat_f,
split=hf_args.get("test_split"),
)
else:
test = []
train_collection.append(train)
valid_collection.append(valid)
test_collection.append(test)
return (
CompletionsDatasetCollection(train_collection),
CompletionsDatasetCollection(valid_collection),
CompletionsDatasetCollection(test_collection),
)
else: else:
train, valid = [], [] hf_args = args.hf_dataset
if args.test: dataset_name = hf_args["name"]
test = create_hf_dataset(split=hf_args.get("test_split")) print(f"Loading Hugging Face dataset {dataset_name}.")
else: text_feature, prompt_feature, completion_feature, chat_feature = (
test = [] get_hf_custom_features(hf_args)
)
if args.train:
train, valid = get_train_and_valid_splits(hf_args, dataset_name)
else:
train, valid = [], []
if args.test:
test = create_hf_dataset(
dataset_name,
text_feature,
prompt_feature,
completion_feature,
chat_feature,
split=hf_args.get("test_split"),
)
else:
test = []
return train, valid, test return train, valid, test