diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md index 1490c752..d51bce82 100644 --- a/llms/mlx_lm/LORA.md +++ b/llms/mlx_lm/LORA.md @@ -299,7 +299,7 @@ it on the command line. For example, pass `--data mlx-community/wikisql` to train on the pre-formatted WikiwSQL data. Otherwise, provide a mapping of keys in the dataset to the features MLX LM -expects. Use a YAML config to specify the Hugging Face (HF) dataset arguments. For +expects. Use a YAML config to specify the Hugging Face dataset arguments. For example: ```yaml @@ -316,19 +316,17 @@ hf_dataset: - To specify the train, valid, or test splits, set the corresponding `{train,valid,test}_split` argument. -You can specify a list of HF datasets using the `hf_datasets` (plural) configuration, which is a list of records -each with the same structure as above. For example: +You can specify a list of Hugging Face datasets with a list of records each +with the same structure as above. For example: ```yaml -hf_datasets: -- hf_dataset: - name: "Open-Orca/OpenOrca" +hf_dataset: + - name: "Open-Orca/OpenOrca" train_split: "train[:90%]" valid_split: "train[-10%:]" prompt_feature: "question" completion_feature: "response" -- hf_dataset: - name: "trl-lib/ultrafeedback_binarized" + - name: "trl-lib/ultrafeedback_binarized" train_split: "train[:90%]" valid_split: "train[-10%:]" chat_feature: "chosen" diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py index 192e5e30..5ac58aa1 100644 --- a/llms/mlx_lm/lora.py +++ b/llms/mlx_lm/lora.py @@ -61,7 +61,6 @@ CONFIG_DEFAULTS = { "config": None, "grad_checkpoint": False, "lr_schedule": None, - "hf_datasets": None, "lora_parameters": {"rank": 8, "alpha": 16, "dropout": 0.0, "scale": 10.0}, "response_template": None, } diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 2e617cf0..1f990fb7 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -1,6 +1,7 @@ +import itertools import json from pathlib import Path -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from transformers import PreTrainedTokenizer @@ -34,7 +35,12 @@ class ChatDataset: https://platform.openai.com/docs/guides/fine-tuning/example-format """ - def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, chat_key: str = "messages"): + def __init__( + self, + data: List[Dict[str, str]], + tokenizer: PreTrainedTokenizer, + chat_key: str = "messages", + ): self._data = [ tokenizer.apply_chat_template( d[chat_key], @@ -42,7 +48,6 @@ class ChatDataset: ) for d in data ] - self._chat_key = chat_key def __getitem__(self, idx: int): return self._data[idx] @@ -82,48 +87,15 @@ class CompletionsDataset: return len(self._data) -class CompletionsDatasetCollection: - def __init__(self, data: List[Union[ChatDataset, CompletionsDataset]]): - self.collection = data - - def __fetch_and_process_item__(self, idx: int, handler_fn: Callable): - iteration = iter(self.collection) - item = next(iteration) - - curr_idx = idx - - while True: - try: - if (curr_idx + 1) <= len(item): - return handler_fn(item, curr_idx) - else: - curr_idx -= len(item) - item = next(iteration) - except StopIteration: - raise IndexError(idx) +class ConcatenatedDataset: + def __init__(self, data: List[Any]): + self._data = list(itertools.chain(*data)) def __getitem__(self, idx: int): - def getitem(dataset: CompletionsDataset, index: int): - return dataset[index] - - return self.__fetch_and_process_item__(idx, getitem) - - def get_item( - self, idx: int, tokenize: bool = False, add_generation_prompt: bool = True - ) -> str: - def getitem(dataset: CompletionsDataset, index: int): - return dataset.get_item(index, tokenize, add_generation_prompt) - - return self.__fetch_and_process_item__(idx, getitem) - - def get_prompt_and_completion(self, idx: int): - def getitem(dataset: CompletionsDataset, index: int): - return dataset.get_prompt_and_completion(index) - - return self.__fetch_and_process_item__(idx, getitem) + return self._data[idx] def __len__(self): - return sum(map(len, self.collection)) + return len(self._data) def create_dataset( @@ -206,11 +178,12 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): completion_feature, chat_feature, split, + config, ): ds = datasets.load_dataset( dataset_name, split=split, - **hf_args.get("config", {}), + **config, ) if prompt_feature and completion_feature: return CompletionsDataset(ds, tokenizer, prompt_feature, completion_feature) @@ -224,54 +197,68 @@ def load_custom_hf_dataset(args, tokenizer: PreTrainedTokenizer): " or a text feature for the Hugging Face dataset." ) - def get_train_and_valid_splits(hf_args, ds_name): - text_f = hf_args.get("text_feature", None) - prompt_f = hf_args.get("prompt_feature", None) - completion_f = hf_args.get("completion_feature", None) - chat_f = hf_args.get("chat_feature", None) + dataset_collection = args.hf_dataset + if isinstance(dataset_collection, dict): + dataset_collection = [dataset_collection] + + collection = [] + for ds in dataset_collection: + ds_name = ds["name"] + print(f"Loading Hugging Face dataset {ds_name}.") + text_f = ds.get("text_feature", None) + prompt_f = ds.get("prompt_feature", None) + completion_f = ds.get("completion_feature", None) + chat_f = ds.get("chat_feature", None) + ds_config = ds.get("config", {}) if args.train: - train_split = hf_args.get("train_split", "train[:80%]") - valid_split = hf_args.get("valid_split", "train[-10%:]") + train_split = ds.get("train_split", "train[:80%]") + valid_split = ds.get("valid_split", "train[-10%:]") train = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=train_split + ds_name, + text_f, + prompt_f, + completion_f, + chat_f, + train_split, + ds_config, ) valid = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=valid_split + ds_name, + text_f, + prompt_f, + completion_f, + chat_f, + valid_split, + ds_config, ) else: train, valid = [], [] if args.test: - test_split = hf_args.get("test_split") + test_split = ds.get("test_split") test = create_hf_dataset( - ds_name, text_f, prompt_f, completion_f, chat_f, split=test_split, + ds_name, + text_f, + prompt_f, + completion_f, + chat_f, + test_split, + ds_config, ) else: test = [] - return train, valid, test + collection.append((train, valid, test)) - if args.datasets: - dataset_collection = args.hf_datasets - else: - dataset_collection = {"hf_dataset": args.hf_dataset} - - datasets = [] - for ds in dataset_collection: - hf_args = ds["hf_dataset"] - dataset_name = hf_args["name"] - print(f"Loading Hugging Face dataset {dataset_name}.") - datasets.append(get_splits(hf_args, dataset_name)) - if len(datsets) == 1: - return *datasets + if len(collection) == 1: + return collection[0] # Otherwise concatenate them - train, valid, test = zip(*datasets) - return tuple(map, Concatenate, zip(*datasets)) + return tuple(map(ConcatenatedDataset, zip(*collection))) def load_dataset(args, tokenizer: PreTrainedTokenizer): - if getattr(args, "hf_dataset", False) or getattr(args, "hf_datasets", False): + if getattr(args, "hf_dataset", False): train, valid, test = load_custom_hf_dataset(args, tokenizer) else: data_path = Path(args.data) diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py index dd86d277..5edab8bf 100644 --- a/llms/tests/test_datsets.py +++ b/llms/tests/test_datsets.py @@ -78,14 +78,15 @@ class TestDatasets(unittest.TestCase): self.assertTrue(isinstance(train, datasets.ChatDataset)) def test_hf(self): + hf_args = { + "name": "billsum", + "prompt_feature": "text", + "completion_feature": "summary", + "train_split": "train[:2%]", + "valid_split": "train[-2%:]", + } args = types.SimpleNamespace( - hf_dataset={ - "name": "billsum", - "prompt_feature": "text", - "completion_feature": "summary", - "train_split": "train[:2%]", - "valid_split": "train[-2%:]", - }, + hf_dataset=hf_args, test=False, train=True, ) @@ -97,6 +98,16 @@ class TestDatasets(unittest.TestCase): self.assertTrue(len(valid[0]) > 0) self.assertEqual(len(test), 0) + args = types.SimpleNamespace( + hf_dataset=[hf_args, hf_args], + test=False, + train=True, + ) + train_double, valid_double, test_double = datasets.load_dataset(args, tokenizer) + self.assertEqual(2 * len(train), len(train_double)) + self.assertEqual(2 * len(valid), len(valid_double)) + self.assertEqual(2 * len(test), len(test_double)) + if __name__ == "__main__": unittest.main()