From df6bc09d7471f05f7aec69b2bfa54290a60b22af Mon Sep 17 00:00:00 2001
From: Chime Ogbuji
Date: Wed, 26 Jun 2024 13:20:50 -0400
Subject: [PATCH] Configuration-based use of HF hub-hosted datasets for
 training (#701)

* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA
  training

* Pre-commit formatting

* Fix YAML config example

* Print DS info

* Include name

* Add hf_dataset parameter default

* Remove TextHFDataset and CompletionsHFDataset and use Dataset and
  CompletionsDataset instead, adding a text_key constructor argument to the
  former (and changing it to work with a provided data structure instead of
  just from a JSON file), and prompt_key and completion_key arguments to the
  latter with defaults for backwards compatibility.

* nits

* update docs

---------

Co-authored-by: Awni Hannun
---
 .circleci/config.yml                  |  2 +-
 llms/mlx_lm/LORA.md                   | 36 +++++++++-
 llms/mlx_lm/examples/lora_config.yaml |  8 +++
 llms/mlx_lm/tuner/datasets.py         | 99 ++++++++++++++++++++-------
 llms/mlx_lm/version.py                |  2 +-
 llms/setup.py                         |  3 +
 llms/tests/test_datsets.py            | 18 +++++
 7 files changed, 140 insertions(+), 28 deletions(-)

diff --git a/.circleci/config.yml b/.circleci/config.yml
index 556f209e..02fa1de8 100644
--- a/.circleci/config.yml
+++ b/.circleci/config.yml
@@ -32,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e .
+            pip install -e ".[testing]"
       - run:
           name: Run Python tests
          command: |
diff --git a/llms/mlx_lm/LORA.md b/llms/mlx_lm/LORA.md
index 3d65f213..2e739d0f 100644
--- a/llms/mlx_lm/LORA.md
+++ b/llms/mlx_lm/LORA.md
@@ -151,9 +151,14 @@ Examples GitHub repo has an [example of the WikiSQL
 data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
 correct format.
 
+Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
+Face.
+
+### Local Datasets
+
 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
-loader expects a `test.jsonl` in the data directory.
+loader expects a `test.jsonl` in the data directory.
 
 Currently, `*.jsonl` files support three data formats: `chat`,
 `completions`, and `text`. Here are three examples of these formats:
@@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,
 Note, the format is automatically determined by the dataset. Note also, keys
 in each line not expected by the loader will be ignored.
 
-For the `chat` and `completions` formats, Hugging Face [chat
+### Hugging Face Datasets
+
+To use Hugging Face datasets, first install the `datasets` package:
+
+```
+pip install datasets
+```
+
+Specify the Hugging Face dataset arguments in a YAML config. For example:
+
+```
+hf_dataset:
+  name: "billsum"
+  prompt_feature: "text"
+  completion_feature: "summary"
+```
+
+- Use `prompt_feature` and `completion_feature` to specify keys for a
+  `completions` dataset. Use `text_feature` to specify the key for a `text`
+  dataset.
+
+- To specify the train, valid, or test splits, set the corresponding
+  `{train,valid,test}_split` argument.
+
+- Arguments specified in `config` will be passed as keyword arguments to
+  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
+
+In general, for the `chat` and `completions` formats, Hugging Face [chat
 templates](https://huggingface.co/blog/chat-templates) are used. This applies
 the model's chat template by default. If the model does not have a chat
 template, then Hugging Face will use a default. For example, the final text in
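For readers new to the `datasets` API, the `hf_dataset` section documented above expands, roughly, into a `datasets.load_dataset` call plus a choice of columns. The sketch below is illustrative only: the split string is an assumed example, and `text`/`summary` are the `billsum` columns named by `prompt_feature` and `completion_feature` in the example config.

```python
# Rough, illustrative sketch of what the hf_dataset YAML maps onto.
# The split string is an assumed example; "text" and "summary" are the
# billsum columns referenced by prompt_feature/completion_feature above.
import datasets

ds = datasets.load_dataset("billsum", split="train[:1000]")

record = ds[0]
prompt, completion = record["text"], record["summary"]
print(prompt[:80], "->", completion[:80])
```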
diff --git a/llms/mlx_lm/examples/lora_config.yaml b/llms/mlx_lm/examples/lora_config.yaml
index d3c0d22a..073a5b6f 100644
--- a/llms/mlx_lm/examples/lora_config.yaml
+++ b/llms/mlx_lm/examples/lora_config.yaml
@@ -69,3 +69,11 @@ lora_parameters:
 #  warmup: 100 # 0 for no warmup
 #  warmup_init: 1e-7 # 0 if not specified
 #  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
+
+#hf_dataset:
+#  name: "billsum"
+#  train_split: "train[:1000]"
+#  valid_split: "train[-100:]"
+#  prompt_feature: "text"
+#  completion_feature: "summary"
+
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index e5776160..3d99894c 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -1,20 +1,21 @@
 import json
 from pathlib import Path
+from typing import Dict, List
 
 from transformers import PreTrainedTokenizer
 
 
 class Dataset:
     """
-    Light-weight wrapper to hold lines from a jsonl file
+    Light-weight wrapper to hold a dataset.
     """
 
-    def __init__(self, path: Path):
-        with open(path, "r") as fid:
-            self._data = [json.loads(l) for l in fid]
+    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
+        self._text_key = text_key
+        self._data = data
 
     def __getitem__(self, idx: int):
-        return self._data[idx]["text"]
+        return self._data[idx][self._text_key]
 
     def __len__(self):
         if self._data is None:
@@ -28,8 +29,8 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+        super().__init__(data)
         self._tokenizer = tokenizer
 
     def __getitem__(self, idx: int):
@@ -43,19 +44,28 @@ class ChatDataset(Dataset):
 class CompletionsDataset(Dataset):
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
+    or using user-provided keys for prompt and completion values
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        completion_key: str = "completion",
+    ):
+        super().__init__(data)
         self._tokenizer = tokenizer
+        self._prompt_key = prompt_key
+        self._completion_key = completion_key
 
     def __getitem__(self, idx: int):
         data = self._data[idx]
         text = self._tokenizer.apply_chat_template(
             [
-                {"role": "user", "content": data["prompt"]},
-                {"role": "assistant", "content": data["completion"]},
+                {"role": "user", "content": data[self._prompt_key]},
+                {"role": "assistant", "content": data[self._completion_key]},
             ],
             tokenize=False,
             add_generation_prompt=True,
@@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
     if not path.exists():
         return []
     with open(path, "r") as fid:
-        first_line = next(fid)
-        first_obj = json.loads(first_line)
-        if "messages" in first_obj:
-            return ChatDataset(path, tokenizer)
-        elif "prompt" in first_obj and "completion" in first_obj:
-            return CompletionsDataset(path, tokenizer)
-        elif "text" in first_obj:
-            return Dataset(path)
+        data = [json.loads(l) for l in fid]
+    if "messages" in data[0]:
+        return ChatDataset(data, tokenizer)
+    elif "prompt" in data[0] and "completion" in data[0]:
+        return CompletionsDataset(data, tokenizer)
+    elif "text" in data[0]:
+        return Dataset(data)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -84,11 +93,53 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    names = ("train", "valid", "test")
-    data_path = Path(args.data)
-    train, valid, test = [
-        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-    ]
+    if getattr(args, "hf_dataset", None) is not None:
+        import datasets
+
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature = hf_args.get("text_feature")
+        prompt_feature = hf_args.get("prompt_feature")
+        completion_feature = hf_args.get("completion_feature")
+
+        def create_hf_dataset(split: str = None):
+            ds = datasets.load_dataset(
+                dataset_name,
+                split=split,
+                **hf_args.get("config", {}),
+            )
+            if prompt_feature and completion_feature:
+                return CompletionsDataset(
+                    ds, tokenizer, prompt_feature, completion_feature
+                )
+            elif text_feature:
+                return Dataset(ds, text_key=text_feature)
+            else:
+                raise ValueError(
+                    "Specify either a prompt and completion feature or a text "
+                    "feature for the Hugging Face dataset."
+                )
+
+        if args.train:
+            train_split = hf_args.get("train_split", "train[:80%]")
+            valid_split = hf_args.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(split=train_split)
+            valid = create_hf_dataset(split=valid_split)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(split=hf_args.get("test_split"))
+        else:
+            test = []
+
+    else:
+        names = ("train", "valid", "test")
+        data_path = Path(args.data)
+
+        train, valid, test = [
+            create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+        ]
     if args.train and len(train) == 0:
         raise ValueError(
             "Training set not found or empty. Must provide training set for fine-tuning."
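To make the refactored constructors above concrete, here is a minimal usage sketch of `CompletionsDataset` with in-memory data. The sample record, its key names, and the model path are hypothetical stand-ins; any tokenizer whose model defines a chat template behaves the same way.

```python
# Minimal sketch of the refactored CompletionsDataset; the sample record,
# its key names, and the model path are hypothetical stand-ins.
from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import CompletionsDataset

data = [{"question": "What is the capital of France?", "answer": "Paris."}]
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Meta-Llama-3-8B-Instruct-4bit")

# prompt_key/completion_key override the default "prompt"/"completion" keys.
ds = CompletionsDataset(data, tokenizer, prompt_key="question", completion_key="answer")
print(ds[0])  # the record rendered through the tokenizer's chat template
```

Accepting a plain list of dicts rather than a file path is what lets the same classes back both local `*.jsonl` files and Hugging Face datasets, since a `datasets.Dataset` also supports indexing and `len()`.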
diff --git a/llms/mlx_lm/version.py b/llms/mlx_lm/version.py
index 88c3e75e..40b73ede 100644
--- a/llms/mlx_lm/version.py
+++ b/llms/mlx_lm/version.py
@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.15.0"
+__version__ = "0.16.0"
diff --git a/llms/setup.py b/llms/setup.py
index 648e1e04..88deed17 100644
--- a/llms/setup.py
+++ b/llms/setup.py
@@ -26,6 +26,9 @@ setup(
     install_requires=requirements,
     packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
     python_requires=">=3.8",
+    extras_require={
+        "testing": ["datasets"],
+    },
     entry_points={
         "console_scripts": [
             "mlx_lm.convert = mlx_lm.convert:main",
diff --git a/llms/tests/test_datsets.py b/llms/tests/test_datsets.py
index 8d8c01a5..240bfb4a 100644
--- a/llms/tests/test_datsets.py
+++ b/llms/tests/test_datsets.py
@@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(len(valid[0]) > 0)
         self.assertTrue(isinstance(train, datasets.ChatDataset))
 
+    def test_hf(self):
+        args = types.SimpleNamespace(
+            hf_dataset={
+                "name": "billsum",
+                "prompt_feature": "text",
+                "completion_feature": "summary",
+            },
+            test=False,
+            train=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
+        self.assertTrue(len(train) > 0)
+        self.assertTrue(len(train[0]) > 0)
+        self.assertTrue(len(valid) > 0)
+        self.assertTrue(len(valid[0]) > 0)
+        self.assertEqual(len(test), 0)
+
 
 if __name__ == "__main__":
     unittest.main()
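Finally, a sketch that exercises the new `hf_dataset` code path end to end, mirroring `test_hf` above. The tokenizer path is again a hypothetical stand-in, and the splits are the ones from the example YAML; only the `train`/`test` flags and the `hf_dataset` mapping are read on this path.

```python
# End-to-end sketch of the hf_dataset path, mirroring test_hf above.
# The tokenizer path is a hypothetical stand-in.
import types

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import load_dataset

args = types.SimpleNamespace(
    hf_dataset={
        "name": "billsum",
        "train_split": "train[:1000]",  # splits from the example YAML
        "valid_split": "train[-100:]",
        "prompt_feature": "text",
        "completion_feature": "summary",
    },
    train=True,
    test=False,
)
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Meta-Llama-3-8B-Instruct-4bit")
train, valid, test = load_dataset(args, tokenizer)
print(len(train), len(valid), len(test))  # expected: 1000 100 0
```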