Awni Hannun 2025-02-09 19:37:11 -08:00
parent bb2c8bcf96
commit eda597bdef


@@ -1,5 +1,6 @@
 import itertools
 import json
+import types
 from pathlib import Path
 from typing import Any, Dict, List, Optional
@@ -115,22 +116,26 @@ class ConcatenatedDataset:
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
-    config: Dict,
+    config,
 ):
     mask_prompt = getattr(config, "mask_prompt", False)
     prompt_feature = getattr(config, "prompt_feature", "prompt")
+    text_feature = getattr(config, "text_feature", "text")
     completion_feature = getattr(config, "completion_feature", "completion")
+    chat_feature = getattr(config, "chat_feature", "messages")
     sample = data[0]
-    if "messages" in sample:
-        return ChatDataset(data, tokenizer, mask_prompt=mask_prompt)
-    elif prompt_feature in sample and completion_feature in sample:
+    if prompt_feature in sample and completion_feature in sample:
         return CompletionsDataset(
             data, tokenizer, prompt_feature, completion_feature, mask_prompt
         )
-    elif "text" in sample:
+    elif chat_feature in sample:
+        return ChatDataset(
+            data, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
+        )
+    elif text_feature in sample:
         if mask_prompt:
             raise ValueError("Prompt masking not supported for text dataset.")
-        return Dataset(data, tokenizer)
+        return Dataset(data, tokenizer, text_key=text_feature)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -141,7 +146,7 @@ def create_dataset(
 def load_local_dataset(
     data_path: Path,
     tokenizer: PreTrainedTokenizer,
-    config: Dict,
+    config,
 ):
     def load_subset(path):
         if not path.exists():
@@ -158,7 +163,7 @@ def load_local_dataset(
 def load_hf_dataset(
     data_id: str,
     tokenizer: PreTrainedTokenizer,
-    config: Dict,
+    config,
 ):
     from datasets import exceptions, load_dataset
@@ -185,39 +190,13 @@ def load_hf_dataset(
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
-    mask_prompt = getattr(args, "mask_prompt", False)
-    def create_hf_dataset(
-        dataset_name,
-        text_feature,
-        prompt_feature,
-        completion_feature,
-        chat_feature,
-        split,
-        config,
-    ):
+    def create_hf_dataset(dataset_name, config, split, hf_config):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
-            **config,
+            **hf_config,
         )
-        if prompt_feature and completion_feature:
-            return CompletionsDataset(
-                ds, tokenizer, prompt_feature, completion_feature, mask_prompt
-            )
-        elif chat_feature:
-            return ChatDataset(
-                ds, tokenizer, chat_key=chat_feature, mask_prompt=mask_prompt
-            )
-        elif text_feature:
-            if mask_prompt:
-                raise ValueError("Prompt masking not supported for text dataset.")
-            return Dataset(ds, tokenizer, text_key=text_feature)
-        else:
-            raise ValueError(
-                "Specify either a prompt and completion feature, a chat feature,"
-                " or a text feature for the Hugging Face dataset."
-            )
+        return create_dataset(ds, tokenizer, config)
     dataset_collection = args.hf_dataset
     if isinstance(dataset_collection, dict):
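
A rough sketch of what the slimmed-down create_hf_dataset now amounts to for a single split, written out inline. The dataset id, split string, and tokenizer below are assumptions for illustration, not values from the commit.

import types
import datasets

# Assumed per-dataset spec; only the keys that create_dataset reads matter here.
config = types.SimpleNamespace(text_feature="text", mask_prompt=False)
hf_config = {}  # extra keyword arguments forwarded to datasets.load_dataset

# Hypothetical dataset id and split.
ds = datasets.load_dataset("some-org/some-text-corpus", split="train[:80%]", **hf_config)

# The Hugging Face path now reuses the same dispatch as local datasets.
train = create_dataset(ds, tokenizer, config)
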
@@ -227,31 +206,23 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     for ds in dataset_collection:
         ds_name = ds["name"]
         print(f"Loading Hugging Face dataset {ds_name}.")
-        text_f = ds.get("text_feature", None)
-        prompt_f = ds.get("prompt_feature", None)
-        completion_f = ds.get("completion_feature", None)
-        chat_f = ds.get("chat_feature", None)
-        ds_config = ds.get("config", {})
+        ds["mask_prompt"] = getattr(args, "mask_prompt", False)
+        config = types.SimpleNamespace(**ds)
+        hf_config = ds.get("config", {})
         if args.train:
             train_split = ds.get("train_split", "train[:80%]")
             valid_split = ds.get("valid_split", "train[-10%:]")
             train = create_hf_dataset(
                 ds_name,
-                text_f,
-                prompt_f,
-                completion_f,
-                chat_f,
+                config,
                 train_split,
-                ds_config,
+                hf_config,
             )
             valid = create_hf_dataset(
                 ds_name,
-                text_f,
-                prompt_f,
-                completion_f,
-                chat_f,
+                config,
                 valid_split,
-                ds_config,
+                hf_config,
             )
         else:
             train, valid = [], []
@@ -260,12 +231,9 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
             test_split = ds.get("test_split")
             test = create_hf_dataset(
                 ds_name,
-                text_f,
-                prompt_f,
-                completion_f,
-                chat_f,
+                config,
                 test_split,
-                ds_config,
+                hf_config,
             )
         else:
             test = []
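
For context, a hypothetical shape of args.hf_dataset under the new loop; the dataset ids, feature names, and "config" contents below are invented. Every key of an entry, plus the injected "mask_prompt", becomes an attribute of the SimpleNamespace that create_dataset inspects, while the optional "config" sub-dict is forwarded as keyword arguments to datasets.load_dataset.

# Hypothetical args.hf_dataset value (names and keys are illustrative only).
hf_dataset = [
    {
        "name": "org/chat-corpus",            # assumed dataset id
        "chat_feature": "messages",           # picked up by create_dataset
        "train_split": "train[:80%]",
        "valid_split": "train[-10%:]",
        "config": {"revision": "main"},       # forwarded to datasets.load_dataset
    },
    {
        "name": "org/qa-pairs",               # assumed dataset id
        "prompt_feature": "question",
        "completion_feature": "answer",
        "test_split": "test",
    },
]
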