Configuration-based use of HF hub-hosted datasets for training (#701)

* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA training

* Pre-commit formatting

* Fix YAML config example

* Print DS info

* Include name

* Add hf_dataset parameter default

* Remove TextHFDataset and CompletionsHFDataset in favor of Dataset and CompletionsDataset. Dataset gains a `text_key` constructor argument and now accepts an in-memory data structure rather than only a JSON file; CompletionsDataset gains `prompt_key` and `completion_key` arguments, with defaults preserving backwards compatibility.

* nits

* update docs

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Chime Ogbuji 2024-06-26 13:20:50 -04:00 committed by GitHub
parent 1d701a1831
commit df6bc09d74
7 changed files with 140 additions and 28 deletions

View File

@@ -32,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e .
+            pip install -e ".[testing]"
       - run:
           name: Run Python tests
           command: |

View File

@@ -151,9 +151,14 @@ Examples GitHub repo has an [example of the WikiSQL
 data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
 correct format.

+Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
+Face.
+
+### Local Datasets
+
 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
 loader expects a `test.jsonl` in the data directory.

 Currently, `*.jsonl` files support three data formats: `chat`,
 `completions`, and `text`. Here are three examples of these formats:
@@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,
 Note, the format is automatically determined by the dataset. Note also, keys in
 each line not expected by the loader will be ignored.

+### Hugging Face Datasets
+
+To use Hugging Face datasets, first install the `datasets` package:
+
+```
+pip install datasets
+```
+
+Specify the Hugging Face dataset arguments in a YAML config. For example:
+
+```
+hf_dataset:
+  name: "billsum"
+  prompt_feature: "text"
+  completion_feature: "summary"
+```
+
+- Use `prompt_feature` and `completion_feature` to specify keys for a
+  `completions` dataset. Use `text_feature` to specify the key for a `text`
+  dataset.
+
+- To specify the train, valid, or test splits, set the corresponding
+  `{train,valid,test}_split` argument.
+
+- Arguments specified in `config` will be passed as keyword arguments to
+  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
+
-For the `chat` and `completions` formats, Hugging Face [chat
+In general, for the `chat` and `completions` formats, Hugging Face [chat
 templates](https://huggingface.co/blog/chat-templates) are used. This applies
 the model's chat template by default. If the model does not have a chat
 template, then Hugging Face will use a default. For example, the final text in
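Once such a config is saved, training can be pointed at it from the command line. A minimal sketch, assuming the `mlx_lm.lora` console script and its `--config` option defined in this package (the file name is illustrative):

```
# Illustrative: the hf_dataset section is read from the YAML config.
mlx_lm.lora --config lora_config.yaml --train
```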

View File

@@ -69,3 +69,11 @@ lora_parameters:
 #  warmup: 100 # 0 for no warmup
 #  warmup_init: 1e-7 # 0 if not specified
 #  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
+
+#hf_dataset:
+#  name: "billsum"
+#  train_split: "train[:1000]"
+#  valid_split: "train[-100:]"
+#  prompt_feature: "text"
+#  completion_feature: "summary"
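As a hedged aside, the `{train,valid}_split` strings above use Hugging Face's split-slicing syntax; a minimal sketch of what they resolve to (the `billsum` dataset does expose `text` and `summary` columns):

```
# Minimal sketch: resolve the commented-out splits with datasets.load_dataset.
from datasets import load_dataset

train = load_dataset("billsum", split="train[:1000]")   # first 1000 rows
valid = load_dataset("billsum", split="train[-100:]")   # last 100 rows
print(train[0]["text"][:80])     # the prompt_feature column
print(train[0]["summary"][:80])  # the completion_feature column
```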

View File

@@ -1,20 +1,21 @@
 import json
 from pathlib import Path
+from typing import Dict, List

 from transformers import PreTrainedTokenizer


 class Dataset:
     """
-    Light-weight wrapper to hold lines from a jsonl file
+    Light-weight wrapper to hold a dataset.
     """

-    def __init__(self, path: Path):
-        with open(path, "r") as fid:
-            self._data = [json.loads(l) for l in fid]
+    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
+        self._text_key = text_key
+        self._data = data

     def __getitem__(self, idx: int):
-        return self._data[idx]["text"]
+        return self._data[idx][self._text_key]

     def __len__(self):
         if self._data is None:
@@ -28,8 +29,8 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """

-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+        super().__init__(data)
         self._tokenizer = tokenizer

     def __getitem__(self, idx: int):
@@ -43,19 +44,28 @@ class ChatDataset(Dataset):
 class CompletionsDataset(Dataset):
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
+    or using user-provided keys for prompt and completion values
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """

-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        completion_key: str = "completion",
+    ):
+        super().__init__(data)
         self._tokenizer = tokenizer
+        self._prompt_key = prompt_key
+        self._completion_key = completion_key

     def __getitem__(self, idx: int):
         data = self._data[idx]
         text = self._tokenizer.apply_chat_template(
             [
-                {"role": "user", "content": data["prompt"]},
-                {"role": "assistant", "content": data["completion"]},
+                {"role": "user", "content": data[self._prompt_key]},
+                {"role": "assistant", "content": data[self._completion_key]},
             ],
             tokenize=False,
             add_generation_prompt=True,
@@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
     if not path.exists():
         return []
     with open(path, "r") as fid:
-        first_line = next(fid)
-        first_obj = json.loads(first_line)
-        if "messages" in first_obj:
-            return ChatDataset(path, tokenizer)
-        elif "prompt" in first_obj and "completion" in first_obj:
-            return CompletionsDataset(path, tokenizer)
-        elif "text" in first_obj:
-            return Dataset(path)
+        data = [json.loads(l) for l in fid]
+        if "messages" in data[0]:
+            return ChatDataset(data, tokenizer)
+        elif "prompt" in data[0] and "completion" in data[0]:
+            return CompletionsDataset(data, tokenizer)
+        elif "text" in data[0]:
+            return Dataset(data)
         else:
             raise ValueError(
                 "Unsupported data format, check the supported formats here:\n"
@@ -84,11 +93,53 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):

 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    names = ("train", "valid", "test")
-    data_path = Path(args.data)
-    train, valid, test = [
-        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-    ]
+    if getattr(args, "hf_dataset", None) is not None:
+        import datasets
+
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature = hf_args.get("text_feature")
+        prompt_feature = hf_args.get("prompt_feature")
+        completion_feature = hf_args.get("completion_feature")
+
+        def create_hf_dataset(split: str = None):
+            ds = datasets.load_dataset(
+                dataset_name,
+                split=split,
+                **hf_args.get("config", {}),
+            )
+            if prompt_feature and completion_feature:
+                return CompletionsDataset(
+                    ds, tokenizer, prompt_feature, completion_feature
+                )
+            elif text_feature:
+                return Dataset(ds, text_key=text_feature)
+            else:
+                raise ValueError(
+                    "Specify either a prompt and completion feature or a text "
+                    "feature for the Hugging Face dataset."
+                )
+
+        if args.train:
+            train_split = hf_args.get("train_split", "train[:80%]")
+            valid_split = hf_args.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(split=train_split)
+            valid = create_hf_dataset(split=valid_split)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(split=hf_args.get("test_split"))
+        else:
+            test = []
+    else:
+        names = ("train", "valid", "test")
+        data_path = Path(args.data)
+        train, valid, test = [
+            create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+        ]
+
     if args.train and len(train) == 0:
         raise ValueError(
             "Training set not found or empty. Must provide training set for fine-tuning."

View File

@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.

-__version__ = "0.15.0"
+__version__ = "0.16.0"

View File

@@ -26,6 +26,9 @@ setup(
     install_requires=requirements,
     packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
     python_requires=">=3.8",
+    extras_require={
+        "testing": ["datasets"],
+    },
     entry_points={
         "console_scripts": [
             "mlx_lm.convert = mlx_lm.convert:main",

View File

@@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(len(valid[0]) > 0)
         self.assertTrue(isinstance(train, datasets.ChatDataset))

+    def test_hf(self):
+        args = types.SimpleNamespace(
+            hf_dataset={
+                "name": "billsum",
+                "prompt_feature": "text",
+                "completion_feature": "summary",
+            },
+            test=False,
+            train=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
+        self.assertTrue(len(train) > 0)
+        self.assertTrue(len(train[0]) > 0)
+        self.assertTrue(len(valid) > 0)
+        self.assertTrue(len(valid[0]) > 0)
+        self.assertEqual(len(test), 0)
+

 if __name__ == "__main__":
     unittest.main()