From 14a75f3f03420fd7ea6886170e19721fdf16c7c7 Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Sun, 3 Nov 2024 19:11:54 -0500
Subject: [PATCH] Generalize HF datasets to a collection of HF datasets via
 `datasets`, add support for custom chat HF datasets (#1088), and fix (#1087)

---
 llms/mlx_lm/tuner/datasets.py | 112 +++++++++++++++++++++++++++---------
 1 file changed, 88 insertions(+), 24 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 42d19a09..bef354c9 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -34,14 +34,20 @@ class ChatDataset:
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        chat_key: str = "messages",
+    ):
         self._data = [
             tokenizer.apply_chat_template(
-                d["messages"],
+                d[chat_key],
                 tools=d.get("tools", None),
             )
             for d in data
         ]
+        self._chat_key = chat_key
 
     def __getitem__(self, idx: int):
         return self._data[idx]
@@ -84,6 +90,24 @@ class CompletionsDataset:
         return len(self._data)
 
 
+class CompletionsDatasetCollection:
+    def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]):
+        self.collection = data
+
+    def __getitem__(self, idx: int):
+        # Walk the sub-datasets in order, shifting the index down by the
+        # length of each dataset it passes over.
+        curr_idx = idx
+        for dataset in self.collection:
+            if curr_idx < len(dataset):
+                return dataset[curr_idx]
+            curr_idx -= len(dataset)
+        raise IndexError(idx)
+
+    def __len__(self):
+        return sum(map(len, self.collection))
+
+
 def create_dataset(
     data,
     tokenizer: PreTrainedTokenizer,
@@ -157,14 +181,14 @@ def load_hf_dataset(
 
 def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
     import datasets
-    hf_args = args.hf_dataset
-    dataset_name = hf_args["name"]
-    print(f"Loading Hugging Face dataset {dataset_name}.")
-    text_feature = hf_args.get("text_feature")
-    prompt_feature = hf_args.get("prompt_feature")
-    completion_feature = hf_args.get("completion_feature")
-
-    def create_hf_dataset(split: str = None):
+    def create_hf_dataset(
+        dataset_name,
+        text_feature,
+        prompt_feature,
+        completion_feature,
+        chat_feature,
+        split,
+    ):
         ds = datasets.load_dataset(
             dataset_name,
             split=split,
@@ -172,27 +196,67 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer):
         )
         if prompt_feature and completion_feature:
             return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature)
+        elif chat_feature:
+            return ChatDataset(ds, tokenizer, chat_key=chat_feature)
         elif text_feature:
             return Dataset(ds, tokenizer, text_key=text_feature)
         else:
             raise ValueError(
-                "Specify either a prompt and completion feature or a text "
-                "feature for the Hugging Face dataset."
+                "Specify either a prompt and completion feature, a chat feature,"
+                " or a text feature for the Hugging Face dataset."
             )
 
-    if args.train:
-        train_split = hf_args.get("train_split", "train[:80%]")
-        valid_split = hf_args.get("valid_split", "train[-10%:]")
-        train = create_hf_dataset(split=train_split)
-        valid = create_hf_dataset(split=valid_split)
-    else:
-        train, valid = [], []
-    if args.test:
-        test = create_hf_dataset(split=hf_args.get("test_split"))
-    else:
-        test = []
+    def get_train_and_valid_splits(hf_args, ds_name):
+        text_f = hf_args.get("text_feature", None)
+        prompt_f = hf_args.get("prompt_feature", None)
+        completion_f = hf_args.get("completion_feature", None)
+        chat_f = hf_args.get("chat_feature", None)
 
-    return train, valid, test
+        if args.train:
+            train_split = hf_args.get("train_split", "train[:80%]")
+            valid_split = hf_args.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(
+                ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split
+            )
+            valid = create_hf_dataset(
+                ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split
+            )
+        else:
+            train, valid = [], []
+
+        if args.test:
+            test_split = hf_args.get("test_split")
+            test = create_hf_dataset(
+                ds_name, text_f, prompt_f, completion_f, chat_f, split=test_split
+            )
+        else:
+            test = []
+
+        return train, valid, test
+
+    if args.hf_datasets:
+        dataset_collection = args.hf_datasets
+    else:
+        dataset_collection = [{"hf_dataset": args.hf_dataset}]
+
+    # Collect per-dataset splits; avoid shadowing the `datasets` module above.
+    dataset_splits = []
+    for ds in dataset_collection:
+        hf_args = ds["hf_dataset"]
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        dataset_splits.append(get_train_and_valid_splits(hf_args, dataset_name))
+
+    if len(dataset_splits) == 1:
+        return dataset_splits[0]
+
+    # Otherwise concatenate the matching splits across all datasets.
+    train, valid, test = zip(*dataset_splits)
+    return (
+        CompletionsDatasetCollection(train),
+        CompletionsDatasetCollection(valid),
+        CompletionsDatasetCollection(test),
+    )
 
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
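
Note on the new `CompletionsDatasetCollection`: it presents an ordered list of
datasets as one indexable dataset, walking the sub-datasets and offsetting the
index as it goes. A minimal sketch of the expected behavior, using plain lists
as hypothetical stand-ins for ChatDataset/CompletionsDataset instances
(anything exposing `__getitem__` and `__len__` works):

    # Two stand-in datasets of lengths 2 and 3.
    combined = CompletionsDatasetCollection([["a0", "a1"], ["b0", "b1", "b2"]])

    assert len(combined) == 5           # 2 + 3
    assert combined[1] == "a1"          # falls inside the first dataset
    assert combined[2] == "b0"          # 2 - len(first) == 0 in the second
    # combined[5] raises IndexError(5) after walking off the last dataset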
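
Note on configuration: the new code path reads `args.hf_datasets` when present
and falls back to the single `args.hf_dataset`. A hedged sketch of the shape
it expects, with illustrative dataset names and feature keys that are not from
the patch:

    args.hf_datasets = [
        {"hf_dataset": {"name": "org/chat-data",
                        "chat_feature": "conversations"}},
        {"hf_dataset": {"name": "org/qa-data",
                        "prompt_feature": "question",
                        "completion_feature": "answer"}},
    ]

With one entry the splits are returned directly; with several, the matching
train/valid/test splits are wrapped in `CompletionsDatasetCollection`. The
collection's type hint uses `Union`, so the module's `typing` import must
include it for this patch to run.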