Configuration-based use of HF hub-hosted datasets for training (#701)

* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA training

* Pre-commit formatting

* Fix YAML config example

* Print DS info

* Include name

* Add hf_dataset parameter default

* Remove TextHFDataset and CompletionsHFDataset in favor of the existing Dataset and CompletionsDataset. Dataset gains a text_key constructor argument and now wraps a provided data structure rather than reading a JSON file itself; CompletionsDataset gains prompt_key and completion_key arguments, with defaults that preserve backwards compatibility. (A sketch of the new constructors follows this list.)

* nits

* update docs
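A minimal sketch of the refactored constructors described above, assuming the
module path mlx_lm.tuner.datasets from this repo's layout; the sample records
and tokenizer name are illustrative, not part of the commit:

```
from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import CompletionsDataset, Dataset

# Dataset now wraps an in-memory list of dicts; text_key selects the field
# to return (the default "text" preserves the old behavior).
records = [{"content": "Some raw training text."}]
ds = Dataset(records, text_key="content")
print(ds[0])  # -> "Some raw training text."

# CompletionsDataset keeps "prompt"/"completion" as defaults for backwards
# compatibility, but both keys can be remapped to an arbitrary schema.
pairs = [{"question": "What is MLX?", "answer": "An array framework for Apple silicon."}]
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # illustrative
completions = CompletionsDataset(
    pairs, tokenizer, prompt_key="question", completion_key="answer"
)
print(completions[0])  # rendered through the tokenizer's chat template
```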

---------

Co-authored-by: Awni Hannun <awni@apple.com>
Commit df6bc09d74 (parent 1d701a1831)
Authored by Chime Ogbuji on 2024-06-26 13:20:50 -04:00; committed by GitHub.
7 changed files with 140 additions and 28 deletions


@@ -32,7 +32,7 @@ jobs:
             pip install --upgrade pip
             pip install unittest-xml-reporting
             cd llms/
-            pip install -e .
+            pip install -e ".[testing]"
       - run:
           name: Run Python tests
           command: |


@@ -151,6 +151,11 @@ Examples GitHub repo has an [example of the WikiSQL
 data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
 correct format.
 
+Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
+Face.
+
+### Local Datasets
+
 For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
 `valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
 loader expects a `test.jsonl` in the data directory.
@@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,
 Note, the format is automatically determined by the dataset. Note also, keys in
 each line not expected by the loader will be ignored.
 
-For the `chat` and `completions` formats, Hugging Face [chat
+### Hugging Face Datasets
+
+To use Hugging Face datasets, first install the `datasets` package:
+
+```
+pip install datasets
+```
+
+Specify the Hugging Face dataset arguments in a YAML config. For example:
+
+```
+hf_dataset:
+  name: "billsum"
+  prompt_feature: "text"
+  completion_feature: "summary"
+```
+
+- Use `prompt_feature` and `completion_feature` to specify keys for a
+  `completions` dataset. Use `text_feature` to specify the key for a `text`
+  dataset.
+
+- To specify the train, valid, or test splits, set the corresponding
+  `{train,valid,test}_split` argument.
+
+- Arguments specified in `config` will be passed as keyword arguments to
+  [`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
+
+In general, for the `chat` and `completions` formats, Hugging Face [chat
 templates](https://huggingface.co/blog/chat-templates) are used. This applies
 the model's chat template by default. If the model does not have a chat
 template, then Hugging Face will use a default. For example, the final text in
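For orientation, the hf_dataset example in the docs above resolves to roughly
the following direct use of the datasets library (a sketch; billsum's text and
summary columns are as documented on the Hugging Face Hub, and the splits shown
are the loader's defaults, visible in the datasets.py diff below):

```
from datasets import load_dataset

train = load_dataset("billsum", split="train[:80%]")   # default train split
valid = load_dataset("billsum", split="train[-10%:]")  # default valid split

sample = train[0]
prompt, completion = sample["text"], sample["summary"]
print(prompt[:80], "->", completion[:80])
```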


@@ -69,3 +69,11 @@ lora_parameters:
 #  warmup: 100 # 0 for no warmup
 #  warmup_init: 1e-7 # 0 if not specified
 #  arguments: [1e-5, 1000, 1e-7] # passed to scheduler
+
+#hf_dataset:
+#  name: "billsum"
+#  train_split: "train[:1000]"
+#  valid_split: "train[-100:]"
+#  prompt_feature: "text"
+#  completion_feature: "summary"
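A sketch of how this block, once uncommented, reaches the loader. The YAML
parsing and namespace merge below are schematic stand-ins for the training
script's own config handling, not code from this commit:

```
import types

import yaml

config_text = """
hf_dataset:
  name: "billsum"
  train_split: "train[:1000]"   # first 1,000 rows
  valid_split: "train[-100:]"   # last 100 rows
  prompt_feature: "text"
  completion_feature: "summary"
"""

# The loader ultimately sees a plain dict under args.hf_dataset.
args = types.SimpleNamespace(train=True, test=False, **yaml.safe_load(config_text))
print(args.hf_dataset["train_split"])  # -> "train[:1000]"
```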


@@ -1,20 +1,21 @@
 import json
 from pathlib import Path
+from typing import Dict, List
 
 from transformers import PreTrainedTokenizer
 
 
 class Dataset:
     """
-    Light-weight wrapper to hold lines from a jsonl file
+    Light-weight wrapper to hold a dataset.
     """
 
-    def __init__(self, path: Path):
-        with open(path, "r") as fid:
-            self._data = [json.loads(l) for l in fid]
+    def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
+        self._text_key = text_key
+        self._data = data
 
     def __getitem__(self, idx: int):
-        return self._data[idx]["text"]
+        return self._data[idx][self._text_key]
 
     def __len__(self):
         if self._data is None:
@@ -28,8 +29,8 @@ class ChatDataset(Dataset):
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
+        super().__init__(data)
         self._tokenizer = tokenizer
 
     def __getitem__(self, idx: int):
@@ -43,19 +44,28 @@ class ChatDataset(Dataset):
 class CompletionsDataset(Dataset):
     """
     A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
+    or using user-provided keys for prompt and completion values
     https://platform.openai.com/docs/guides/fine-tuning/example-format
     """
 
-    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
-        super().__init__(path)
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        completion_key: str = "completion",
+    ):
+        super().__init__(data)
         self._tokenizer = tokenizer
+        self._prompt_key = prompt_key
+        self._completion_key = completion_key
 
     def __getitem__(self, idx: int):
         data = self._data[idx]
         text = self._tokenizer.apply_chat_template(
             [
-                {"role": "user", "content": data["prompt"]},
-                {"role": "assistant", "content": data["completion"]},
+                {"role": "user", "content": data[self._prompt_key]},
+                {"role": "assistant", "content": data[self._completion_key]},
             ],
             tokenize=False,
             add_generation_prompt=True,
@@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
     if not path.exists():
         return []
     with open(path, "r") as fid:
-        first_line = next(fid)
-        first_obj = json.loads(first_line)
-    if "messages" in first_obj:
-        return ChatDataset(path, tokenizer)
-    elif "prompt" in first_obj and "completion" in first_obj:
-        return CompletionsDataset(path, tokenizer)
-    elif "text" in first_obj:
-        return Dataset(path)
+        data = [json.loads(l) for l in fid]
+    if "messages" in data[0]:
+        return ChatDataset(data, tokenizer)
+    elif "prompt" in data[0] and "completion" in data[0]:
+        return CompletionsDataset(data, tokenizer)
+    elif "text" in data[0]:
+        return Dataset(data)
     else:
         raise ValueError(
             "Unsupported data format, check the supported formats here:\n"
@@ -84,11 +93,53 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
 
 def load_dataset(args, tokenizer: PreTrainedTokenizer):
-    names = ("train", "valid", "test")
-    data_path = Path(args.data)
-    train, valid, test = [
-        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
-    ]
+    if getattr(args, "hf_dataset", None) is not None:
+        import datasets
+
+        hf_args = args.hf_dataset
+        dataset_name = hf_args["name"]
+        print(f"Loading Hugging Face dataset {dataset_name}.")
+        text_feature = hf_args.get("text_feature")
+        prompt_feature = hf_args.get("prompt_feature")
+        completion_feature = hf_args.get("completion_feature")
+
+        def create_hf_dataset(split: str = None):
+            ds = datasets.load_dataset(
+                dataset_name,
+                split=split,
+                **hf_args.get("config", {}),
+            )
+            if prompt_feature and completion_feature:
+                return CompletionsDataset(
+                    ds, tokenizer, prompt_feature, completion_feature
+                )
+            elif text_feature:
+                return Dataset(ds, text_key=text_feature)
+            else:
+                raise ValueError(
+                    "Specify either a prompt and completion feature or a text "
+                    "feature for the Hugging Face dataset."
+                )
+
+        if args.train:
+            train_split = hf_args.get("train_split", "train[:80%]")
+            valid_split = hf_args.get("valid_split", "train[-10%:]")
+            train = create_hf_dataset(split=train_split)
+            valid = create_hf_dataset(split=valid_split)
+        else:
+            train, valid = [], []
+        if args.test:
+            test = create_hf_dataset(split=hf_args.get("test_split"))
+        else:
+            test = []
+
+    else:
+        names = ("train", "valid", "test")
+        data_path = Path(args.data)
+        train, valid, test = [
+            create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
+        ]
+
     if args.train and len(train) == 0:
         raise ValueError(
             "Training set not found or empty. Must provide training set for fine-tuning."


@@ -1,3 +1,3 @@
 # Copyright © 2023-2024 Apple Inc.
 
-__version__ = "0.15.0"
+__version__ = "0.16.0"


@@ -26,6 +26,9 @@ setup(
     install_requires=requirements,
     packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
     python_requires=">=3.8",
+    extras_require={
+        "testing": ["datasets"],
+    },
     entry_points={
         "console_scripts": [
             "mlx_lm.convert = mlx_lm.convert:main",


@@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
         self.assertTrue(len(valid[0]) > 0)
         self.assertTrue(isinstance(train, datasets.ChatDataset))
 
+    def test_hf(self):
+        args = types.SimpleNamespace(
+            hf_dataset={
+                "name": "billsum",
+                "prompt_feature": "text",
+                "completion_feature": "summary",
+            },
+            test=False,
+            train=True,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
+        train, valid, test = datasets.load_dataset(args, tokenizer)
+        self.assertTrue(len(train) > 0)
+        self.assertTrue(len(train[0]) > 0)
+        self.assertTrue(len(valid) > 0)
+        self.assertTrue(len(valid[0]) > 0)
+        self.assertEqual(len(test), 0)
+
 
 if __name__ == "__main__":
     unittest.main()