From 9df7bbbe3a252af2fa0b538821050dc5418c4312 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 3 Nov 2024 19:11:54 -0500
Subject: [PATCH] Generalize HF datasets to a collection of HF datasets via
 `datasets`, add support for custom chat HF datasets (#1088), and fix #1087

---
 llms/mlx_lm/tuner/datasets.py | 137 +++++++++++++++++++++++++++++-----
 1 file changed, 117 insertions(+), 20 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 20b32eff..a0c16a28 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,6 +1,6 @@
 import json
 from pathlib import Path
-from typing import Dict, List
+from typing import Dict, List, Union
 
 from transformers import PreTrainedTokenizer
 
@@ -29,12 +29,18 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+    ):
         super().__init__(data)
         self._tokenizer = tokenizer
+        self._chat_key = chat_key
 
     def __getitem__(self, idx: int):
-        messages = self._data[idx]["messages"]
+        messages = self._data[idx][self._chat_key]
         text = self._tokenizer.apply_chat_template(
             messages,
             tools=self._data[idx].get("tools", None),
@@ -76,6 +82,29 @@ class CompletionsDataset(Dataset):
         return text
 
 
+class CompletionsDatasetCollection:
+    def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
+        self.collection = data
+
+    def __getitem__(self, idx: int):
+        # Treat the collection as one flat dataset: walk the constituent
+        # datasets in order, offsetting the index into each in turn.
+        curr_idx = idx
+        for dataset in self.collection:
+            if curr_idx < len(dataset):
+                return dataset[curr_idx]
+            curr_idx -= len(dataset)
+        raise IndexError(idx)
+
+    def __len__(self):
+        return sum(map(len, self.collection))
+
+
 def create_dataset(data, tokenizer: PreTrainedTokenizer = None):
     sample = data[0]
 
@@ -127,14 +156,14 @@ def load_hf_dataset(data_id: str, tokenizer: PreTrainedTokenizer):
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
 
-    hf_args = args.hf_dataset
-    dataset_name = hf_args["name"]
-    print(f"Loading Hugging Face dataset {dataset_name}.")
-    text_feature = hf_args.get("text_feature")
-    prompt_feature = hf_args.get("prompt_feature")
-    completion_feature = hf_args.get("completion_feature")
-
-    def create_hf_dataset(split: str = None):
+    def create_hf_dataset(
+        dataset_name: Union[None, str],
+        text_feature: Union[None, str],
+        prompt_feature: Union[None, str],
+        completion_feature: Union[None, str],
+        chat_feature: Union[None, str],
+        split: Union[None, str] = None,
+    ):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
@@ -142,25 +171,93 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+        elif chat_feature:
+            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
         elif text_feature:
-            return Dataset(train_ds, text_key=text_feature)
+            return Dataset(ds, text_key=text_feature)
         else:
             raise ValueError(
                 "Specify either a prompt and completion feature or a text "
                 "feature for the Hugging Face dataset."
             )
 
-    if args.train:
+    def get_hf_custom_features(hf_args):
+        return (
+            hf_args.get("text_feature"),
+            hf_args.get("prompt_feature"),
+            hf_args.get("completion_feature"),
+            hf_args.get("chat_feature"),
+        )
+
+    def get_train_and_valid_splits(hf_args, ds_name):
         train_split = hf_args.get("train_split", "train[:80%]")
         valid_split = hf_args.get("valid_split", "train[-10%:]")
-        train = create_hf_dataset(split=train_split)
-        valid = create_hf_dataset(split=valid_split)
+        text_f, prompt_f, completion_f, chat_f = get_hf_custom_features(hf_args)
+        train = create_hf_dataset(
+            ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split
+        )
+        valid = create_hf_dataset(
+            ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split
+        )
+        return train, valid
+
+    if args.hf_datasets:
+        dataset_collection = args.hf_datasets
+        train_collection = []
+        valid_collection = []
+        test_collection = []
+        for ds in dataset_collection:
+            hf_args = ds["hf_dataset"]
+            dataset_name = hf_args["name"]
+            print(f"Loading Hugging Face dataset {dataset_name}.")
+            text_feature, prompt_feature, completion_feature, chat_feature = (
+                get_hf_custom_features(hf_args)
+            )
+            if args.train:
+                train, valid = get_train_and_valid_splits(hf_args, dataset_name)
+            else:
+                train, valid = [], []
+            if args.test:
+                test = create_hf_dataset(
+                    dataset_name,
+                    text_feature,
+                    prompt_feature,
+                    completion_feature,
+                    chat_feature,
+                    split=hf_args.get("test_split"),
+                )
+            else:
+                test = []
+            train_collection.append(train)
+            valid_collection.append(valid)
+            test_collection.append(test)
+        return (
+            CompletionsDatasetCollection(train_collection),
+            CompletionsDatasetCollection(valid_collection),
+            CompletionsDatasetCollection(test_collection),
+        )
     else:
-        train, valid = [], []
-    if args.test:
-        test = create_hf_dataset(split=hf_args.get("test_split"))
-    else:
-        test = []
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature, prompt_feature, completion_feature, chat_feature = (
+            get_hf_custom_features(hf_args)
+        )
+        if args.train:
+            train, valid = get_train_and_valid_splits(hf_args, dataset_name)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(
+                dataset_name,
+                text_feature,
+                prompt_feature,
+                completion_feature,
+                chat_feature,
+                split=hf_args.get("test_split"),
+            )
+        else:
+            test = []
 
     return train, valid, test
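
Note: a standalone sketch of the flattened indexing `CompletionsDatasetCollection` is meant to provide, useful for sanity-checking the index arithmetic in the patch. The plain Python lists stand in for `ChatDataset`/`CompletionsDataset` instances; only `__len__` and `__getitem__` are exercised:

```python
# Standalone sketch of CompletionsDatasetCollection's flattened indexing.
# The lists below stand in for ChatDataset/CompletionsDataset instances.
class CompletionsDatasetCollection:
    def __init__(self, data):
        self.collection = data

    def __getitem__(self, idx: int):
        # Offset the index into each constituent dataset in turn.
        curr_idx = idx
        for dataset in self.collection:
            if curr_idx < len(dataset):
                return dataset[curr_idx]
            curr_idx -= len(dataset)
        raise IndexError(idx)

    def __len__(self):
        return sum(map(len, self.collection))


collection = CompletionsDatasetCollection([["a", "b"], ["c"], ["d", "e"]])
assert len(collection) == 5
assert [collection[i] for i in range(5)] == ["a", "b", "c", "d", "e"]
```

The assertions cover the boundary case the arithmetic has to get right: the last element of each constituent dataset stays reachable before the index rolls over to the next dataset.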
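For reference, a minimal sketch of the argument shape the new `hf_datasets` code path appears to expect, read off the code above. The namespace fields mirror what `load_custom_hf_dataset` accesses; the dataset names and feature keys are illustrative placeholders, not a confirmed mlx-lm config:

```python
from types import SimpleNamespace

# Hypothetical args namespace mirroring what load_custom_hf_dataset reads.
# Dataset names and feature keys are illustrative placeholders.
args = SimpleNamespace(
    train=True,
    test=False,
    hf_dataset=None,  # single-dataset path; unused when hf_datasets is set
    hf_datasets=[
        {
            "hf_dataset": {
                "name": "billsum",
                "prompt_feature": "text",
                "completion_feature": "summary",
                "train_split": "train[:80%]",
                "valid_split": "train[-10%:]",
            }
        },
        {
            "hf_dataset": {
                "name": "HuggingFaceH4/no_robots",
                "chat_feature": "messages",
            }
        },
    ],
)

# train, valid, test = load_custom_hf_dataset(args, tokenizer)
```

Each entry wraps its per-dataset options under an `"hf_dataset"` key, matching the `ds["hf_dataset"]` lookup in the loop; the three returned splits are `CompletionsDatasetCollection` instances that concatenate the per-dataset splits.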