simplify collections

This commit is contained in:
Awni Hannun 2025-02-09 08:32:18 -08:00
parent b9748e9ee4
commit 6ace6dc6b2
4 changed files with 81 additions and 86 deletions

View File

@@ -299,7 +299,7 @@ it on the command line. For example, pass `--data mlx-community/wikisql` to
 train on the pre-formatted WikiSQL data.
 
 Otherwise, provide a mapping of keys in the dataset to the features MLX LM
-expects. Use a YAML config to specify the Hugging Face (HF) dataset arguments. For
+expects. Use a YAML config to specify the Hugging Face dataset arguments. For
 example:
 
 ```yaml
@@ -316,19 +316,17 @@ hf_dataset:
 - To specify the train, valid, or test splits, set the corresponding
   `{train,valid,test}_split` argument.
 
-You can specify a list of HF datasets using the `hf_datasets` (plural) configuration, which is a list of records
-each with the same structure as above. For example:
+You can specify a list of Hugging Face datasets with a list of records each
+with the same structure as above. For example:
 
 ```yaml
-hf_datasets:
-  - hf_dataset:
-      name: "Open-Orca/OpenOrca"
-      train_split: "train[:90%]"
-      valid_split: "train[-10%:]"
-      prompt_feature: "question"
-      completion_feature: "response"
-  - hf_dataset:
-      name: "trl-lib/ultrafeedback_binarized"
-      train_split: "train[:90%]"
-      valid_split: "train[-10%:]"
-      chat_feature: "chosen"
+hf_dataset:
+  - name: "Open-Orca/OpenOrca"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    prompt_feature: "question"
+    completion_feature: "response"
+  - name: "trl-lib/ultrafeedback_binarized"
+    train_split: "train[:90%]"
+    valid_split: "train[-10%:]"
+    chat_feature: "chosen"
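For readers who want to see what the new list form looks like once the YAML is parsed, here is a small illustrative sketch. It is not part of the commit; it assumes PyYAML is available and mirrors the normalization that the loader change further down performs (a single mapping or a list of mappings is accepted under the same `hf_dataset` key).

```python
import yaml  # assumption: PyYAML is installed for parsing the config file

config = yaml.safe_load(
    """
hf_dataset:
  - name: "Open-Orca/OpenOrca"
    prompt_feature: "question"
    completion_feature: "response"
  - name: "trl-lib/ultrafeedback_binarized"
    chat_feature: "chosen"
"""
)

hf_dataset = config["hf_dataset"]
# The loader accepts either a single mapping or a list of mappings,
# so a single-dataset config is normalized to a one-element list.
if isinstance(hf_dataset, dict):
    hf_dataset = [hf_dataset]

for ds in hf_dataset:
    print(ds["name"], sorted(k for k in ds if k.endswith("_feature")))
```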

View File

@@ -61,7 +61,6 @@ CONFIG_DEFAULTS = {
     "config": None,
     "grad_checkpoint": False,
     "lr_schedule": None,
-    "hf_datasets": None,
     "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0},
     "response_template": None,
 }
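Since the plural `hf_datasets` default is dropped, a config that used the old plural form moves its entries under the existing `hf_dataset` key as a plain list. A hedged sketch of that migration (the `migrate` helper is hypothetical, not code from this diff; the old and new record shapes follow the README example above):

```python
# Old form: a list of {"hf_dataset": {...}} records under "hf_datasets".
old_style = {
    "hf_datasets": [
        {"hf_dataset": {"name": "Open-Orca/OpenOrca",
                        "prompt_feature": "question",
                        "completion_feature": "response"}},
    ]
}

# New form: a plain list of records under the existing "hf_dataset" key.
new_style = {
    "hf_dataset": [
        {"name": "Open-Orca/OpenOrca",
         "prompt_feature": "question",
         "completion_feature": "response"},
    ]
}

def migrate(config: dict) -> dict:
    """Hypothetical one-off helper: rewrite an old-style config dict."""
    if "hf_datasets" in config:
        config = dict(config)
        config["hf_dataset"] = [d["hf_dataset"] for d in config.pop("hf_datasets")]
    return config

assert migrate(old_style) == new_style
```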

View File

@@ -1,6 +1,7 @@
+import itertools
 import json
 from pathlib import Path
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional
 
 from transformers import PreTrainedTokenizer
@@ -34,7 +35,12 @@ class ChatDataset:
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, chat_key: str = "messages"):
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+    ):
         self._data = [
             tokenizer.apply_chat_template(
                 d[chat_key],
@@ -42,7 +48,6 @@ class ChatDataset:
             )
             for d in data
         ]
-        self._chat_key = chat_key
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -82,48 +87,15 @@ class CompletionsDataset:
         return len(self._data)
 
 
-class CompletionsDatasetCollection:
-    def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
-        self.collection = data
-
-    def __fetch_and_process_item__(self, idx: int, handler_fn: Callable):
-        iteration = iter(self.collection)
-        item = next(iteration)
-        curr_idx = idx
-
-        while True:
-            try:
-                if (curr_idx + 1) <= len(item):
-                    return handler_fn(item, curr_idx)
-                else:
-                    curr_idx -= len(item)
-                    item = next(iteration)
-            except StopIteration:
-                raise IndexError(idx)
-
-    def __getitem__(self, idx: int):
-        def getitem(dataset: CompletionsDataset, index: int):
-            return dataset[index]
-
-        return self.__fetch_and_process_item__(idx, getitem)
-
-    def get_item(
-        self, idx: int, tokenize: bool = False, add_generation_prompt: bool = True
-    ) -> str:
-        def getitem(dataset: CompletionsDataset, index: int):
-            return dataset.get_item(index, tokenize, add_generation_prompt)
-
-        return self.__fetch_and_process_item__(idx, getitem)
-
-    def get_prompt_and_completion(self, idx: int):
-        def getitem(dataset: CompletionsDataset, index: int):
-            return dataset.get_prompt_and_completion(index)
-
-        return self.__fetch_and_process_item__(idx, getitem)
-
-    def __len__(self):
-        return sum(map(len, self.collection))
+class ConcatenatedDataset:
+    def __init__(self, data: List[Any]):
+        self._data = list(itertools.chain(*data))
+
+    def __getitem__(self, idx: int):
+        return self._data[idx]
+
+    def __len__(self):
+        return len(self._data)
 
 
 def create_dataset(
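The replacement class is small enough to read in isolation. Below is an illustrative, self-contained copy with plain lists standing in for the chat/completions datasets: indexing and length now come from one flattened list built eagerly at construction time, instead of the per-lookup index arithmetic the removed collection class did. The trade-off is a single pass over all members up front in exchange for trivial `__getitem__`/`__len__`.

```python
import itertools
from typing import Any, List


class ConcatenatedDataset:
    def __init__(self, data: List[Any]):
        # Eagerly chain every member dataset into one flat list.
        self._data = list(itertools.chain(*data))

    def __getitem__(self, idx: int):
        return self._data[idx]

    def __len__(self):
        return len(self._data)


# Plain lists stand in for ChatDataset / CompletionsDataset instances here.
combined = ConcatenatedDataset([[1, 2, 3], [4, 5]])
assert len(combined) == 5
assert combined[3] == 4
```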
@@ -206,11 +178,12 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         completion_feature,
         chat_feature,
         split,
+        config,
     ):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
-            **hf_args.get("config", {}),
+            **config,
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
@@ -224,54 +197,68 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
                 " or a text feature for the Hugging Face dataset."
             )
 
-    def get_train_and_valid_splits(hf_args, ds_name):
-        text_f = hf_args.get("text_feature", None)
-        prompt_f = hf_args.get("prompt_feature", None)
-        completion_f = hf_args.get("completion_feature", None)
-        chat_f = hf_args.get("chat_feature", None)
+    dataset_collection = args.hf_dataset
+    if isinstance(dataset_collection, dict):
+        dataset_collection = [dataset_collection]
+
+    collection = []
+    for ds in dataset_collection:
+        ds_name = ds["name"]
+        print(f"Loading Hugging Face dataset {ds_name}.")
+        text_f = ds.get("text_feature", None)
+        prompt_f = ds.get("prompt_feature", None)
+        completion_f = ds.get("completion_feature", None)
+        chat_f = ds.get("chat_feature", None)
+        ds_config = ds.get("config", {})
         if args.train:
-            train_split = hf_args.get("train_split", "train[:80%]")
-            valid_split = hf_args.get("valid_split", "train[-10%:]")
+            train_split = ds.get("train_split", "train[:80%]")
+            valid_split = ds.get("valid_split", "train[-10%:]")
             train = create_hf_dataset(
-                ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split
+                ds_name,
+                text_f,
+                prompt_f,
+                completion_f,
+                chat_f,
+                train_split,
+                ds_config,
             )
             valid = create_hf_dataset(
-                ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split
+                ds_name,
+                text_f,
+                prompt_f,
+                completion_f,
+                chat_f,
+                valid_split,
+                ds_config,
             )
         else:
             train, valid = [], []
         if args.test:
-            test_split = hf_args.get("test_split")
+            test_split = ds.get("test_split")
             test = create_hf_dataset(
-                ds_name, text_f, prompt_f, completion_f, chat_f, split=test_split,
+                ds_name,
+                text_f,
+                prompt_f,
+                completion_f,
+                chat_f,
+                test_split,
+                ds_config,
             )
         else:
             test = []
-        return train, valid, test
-
-    if args.datasets:
-        dataset_collection = args.hf_datasets
-    else:
-        dataset_collection = {"hf_dataset": args.hf_dataset}
-
-    datasets = []
-    for ds in dataset_collection:
-        hf_args = ds["hf_dataset"]
-        dataset_name = hf_args["name"]
-        print(f"Loading Hugging Face dataset {dataset_name}.")
-        datasets.append(get_splits(hf_args, dataset_name))
-
-    if len(datsets) == 1:
-        return *datasets
+
+        collection.append((train, valid, test))
+
+    if len(collection) == 1:
+        return collection[0]
 
     # Otherwise concatenate them
-    train, valid, test = zip(*datasets)
-    return tuple(map, Concatenate, zip(*datasets))
+    return tuple(map(ConcatenatedDataset, zip(*collection)))
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    if getattr(args, "hf_dataset", False) or getattr(args, "hf_datasets", False):
+    if getattr(args, "hf_dataset", False):
         train, valid, test = load_custom_hf_dataset(args, tokenizer)
     else:
         data_path = Path(args.data)
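Putting the pieces together, here is a hedged end-to-end sketch of driving the reworked loader. The `mlx_lm.tuner` import path and the tokenizer checkpoint are assumptions, not taken from this diff; the shape of `args` mirrors the updated test below.

```python
import types

from transformers import AutoTokenizer
from mlx_lm.tuner import datasets  # assumed package path for the module changed above

# Any tokenizer with a chat template should do; this checkpoint is an assumption.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

hf_args = {
    "name": "billsum",
    "prompt_feature": "text",
    "completion_feature": "summary",
    "train_split": "train[:2%]",
    "valid_split": "train[-2%:]",
}

# A single dict loads one dataset; a list of dicts is concatenated per split.
args = types.SimpleNamespace(hf_dataset=[hf_args, hf_args], train=True, test=False)
train, valid, test = datasets.load_dataset(args, tokenizer)

# With more than one entry, each split comes back as a ConcatenatedDataset;
# test is empty here because args.test is False.
print(len(train), len(valid), len(test))
```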

View File

@@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(isinstance(train, datasets.ChatDataset))
 
     def test_hf(self):
+        hf_args = {
+            "name": "billsum",
+            "prompt_feature": "text",
+            "completion_feature": "summary",
+            "train_split": "train[:2%]",
+            "valid_split": "train[-2%:]",
+        }
         args = types.SimpleNamespace(
-            hf_dataset={
-                "name": "billsum",
-                "prompt_feature": "text",
-                "completion_feature": "summary",
-                "train_split": "train[:2%]",
-                "valid_split": "train[-2%:]",
-            },
+            hf_dataset=hf_args,
             test=False,
             train=True,
         )
@@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(len(valid[0]) > 0)
         self.assertEqual(len(test), 0)
 
+        args = types.SimpleNamespace(
+            hf_dataset=[hf_args, hf_args],
+            test=False,
+            train=True,
+        )
+        train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer)
+        self.assertEqual(2 * len(train), len(train_double))
+        self.assertEqual(2 * len(valid), len(valid_double))
+        self.assertEqual(2 * len(test), len(test_double))
+
 
 if __name__ == "__main__":
     unittest.main()