mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Configuration-based use of HF hub-hosted datasets for training (#701)
* Add hf_dataset configuration for using HF hub-hosted datasets for (Q)LoRA training * Pre-commit formatting * Fix YAML config example * Print DS info * Include name * Add hf_dataset parameter default * Remove TextHFDataset and CompletionsHFDataset and use Dataset and CompletionsDataset instead, adding a text_key constructor argument to the former (and changing it to work with a provided data structure instead of just from a JSON file), and prompt_key and completion_key arguments to the latter with defaults for backwards compatibility. * nits * update docs --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
1d701a1831
commit
df6bc09d74
@ -32,7 +32,7 @@ jobs:
|
|||||||
pip install --upgrade pip
|
pip install --upgrade pip
|
||||||
pip install unittest-xml-reporting
|
pip install unittest-xml-reporting
|
||||||
cd llms/
|
cd llms/
|
||||||
pip install -e .
|
pip install -e ".[testing]"
|
||||||
- run:
|
- run:
|
||||||
name: Run Python tests
|
name: Run Python tests
|
||||||
command: |
|
command: |
|
||||||
|
@ -151,6 +151,11 @@ Examples GitHub repo has an [example of the WikiSQL
|
|||||||
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
|
data](https://github.com/ml-explore/mlx-examples/tree/main/lora/data) in the
|
||||||
correct format.
|
correct format.
|
||||||
|
|
||||||
|
Datasets can be specified in `*.jsonl` files locally or loaded from Hugging
|
||||||
|
Face.
|
||||||
|
|
||||||
|
### Local Datasets
|
||||||
|
|
||||||
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
||||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||||
loader expects a `test.jsonl` in the data directory.
|
loader expects a `test.jsonl` in the data directory.
|
||||||
@ -199,7 +204,34 @@ Currently, `*.jsonl` files support three data formats: `chat`,
|
|||||||
Note, the format is automatically determined by the dataset. Note also, keys in
|
Note, the format is automatically determined by the dataset. Note also, keys in
|
||||||
each line not expected by the loader will be ignored.
|
each line not expected by the loader will be ignored.
|
||||||
|
|
||||||
For the `chat` and `completions` formats, Hugging Face [chat
|
### Hugging Face Datasets
|
||||||
|
|
||||||
|
To use Hugging Face datasets, first install the `datasets` package:
|
||||||
|
|
||||||
|
```
|
||||||
|
pip install datasets
|
||||||
|
```
|
||||||
|
|
||||||
|
Specify the Hugging Face dataset arguments in a YAML config. For example:
|
||||||
|
|
||||||
|
```
|
||||||
|
hf_dataset:
|
||||||
|
name: "billsum"
|
||||||
|
prompt_feature: "text"
|
||||||
|
completion_feature: "summary"
|
||||||
|
```
|
||||||
|
|
||||||
|
- Use `prompt_feature` and `completion_feature` to specify keys for a
|
||||||
|
`completions` dataset. Use `text_feature` to specify the key for a `text`
|
||||||
|
dataset.
|
||||||
|
|
||||||
|
- To specify the train, valid, or test splits, set the corresponding
|
||||||
|
`{train,valid,test}_split` argument.
|
||||||
|
|
||||||
|
- Arguments specified in `config` will be passed as keyword arguments to
|
||||||
|
[`datasets.load_dataset`](https://huggingface.co/docs/datasets/v2.20.0/en/package_reference/loading_methods#datasets.load_dataset).
|
||||||
|
|
||||||
|
In general, for the `chat` and `completions` formats, Hugging Face [chat
|
||||||
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
||||||
the model's chat template by default. If the model does not have a chat
|
the model's chat template by default. If the model does not have a chat
|
||||||
template, then Hugging Face will use a default. For example, the final text in
|
template, then Hugging Face will use a default. For example, the final text in
|
||||||
|
@ -69,3 +69,11 @@ lora_parameters:
|
|||||||
# warmup: 100 # 0 for no warmup
|
# warmup: 100 # 0 for no warmup
|
||||||
# warmup_init: 1e-7 # 0 if not specified
|
# warmup_init: 1e-7 # 0 if not specified
|
||||||
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
|
# arguments: [1e-5, 1000, 1e-7] # passed to scheduler
|
||||||
|
|
||||||
|
#hf_dataset:
|
||||||
|
# name: "billsum"
|
||||||
|
# train_split: "train[:1000]"
|
||||||
|
# valid_split: "train[-100:]"
|
||||||
|
# prompt_feature: "text"
|
||||||
|
# completion_feature: "summary"
|
||||||
|
|
||||||
|
@ -1,20 +1,21 @@
|
|||||||
import json
|
import json
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from typing import Dict, List
|
||||||
|
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
"""
|
"""
|
||||||
Light-weight wrapper to hold lines from a jsonl file
|
Light-weight wrapper to hold a dataset.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path: Path):
|
def __init__(self, data: List[Dict[str, str]], text_key: str = "text"):
|
||||||
with open(path, "r") as fid:
|
self._text_key = text_key
|
||||||
self._data = [json.loads(l) for l in fid]
|
self._data = data
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
return self._data[idx]["text"]
|
return self._data[idx][self._text_key]
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
@ -28,8 +29,8 @@ class ChatDataset(Dataset):
|
|||||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
|
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer):
|
||||||
super().__init__(path)
|
super().__init__(data)
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
@ -43,19 +44,28 @@ class ChatDataset(Dataset):
|
|||||||
class CompletionsDataset(Dataset):
|
class CompletionsDataset(Dataset):
|
||||||
"""
|
"""
|
||||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||||
|
or using user-provided keys for prompt and completion values
|
||||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
|
def __init__(
|
||||||
super().__init__(path)
|
self,
|
||||||
|
data: List[Dict[str, str]],
|
||||||
|
tokenizer: PreTrainedTokenizer,
|
||||||
|
prompt_key: str = "prompt",
|
||||||
|
completion_key: str = "completion",
|
||||||
|
):
|
||||||
|
super().__init__(data)
|
||||||
self._tokenizer = tokenizer
|
self._tokenizer = tokenizer
|
||||||
|
self._prompt_key = prompt_key
|
||||||
|
self._completion_key = completion_key
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
data = self._data[idx]
|
data = self._data[idx]
|
||||||
text = self._tokenizer.apply_chat_template(
|
text = self._tokenizer.apply_chat_template(
|
||||||
[
|
[
|
||||||
{"role": "user", "content": data["prompt"]},
|
{"role": "user", "content": data[self._prompt_key]},
|
||||||
{"role": "assistant", "content": data["completion"]},
|
{"role": "assistant", "content": data[self._completion_key]},
|
||||||
],
|
],
|
||||||
tokenize=False,
|
tokenize=False,
|
||||||
add_generation_prompt=True,
|
add_generation_prompt=True,
|
||||||
@ -68,14 +78,13 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
|
|||||||
if not path.exists():
|
if not path.exists():
|
||||||
return []
|
return []
|
||||||
with open(path, "r") as fid:
|
with open(path, "r") as fid:
|
||||||
first_line = next(fid)
|
data = [json.loads(l) for l in fid]
|
||||||
first_obj = json.loads(first_line)
|
if "messages" in data[0]:
|
||||||
if "messages" in first_obj:
|
return ChatDataset(data, tokenizer)
|
||||||
return ChatDataset(path, tokenizer)
|
elif "prompt" in data[0] and "completion" in data[0]:
|
||||||
elif "prompt" in first_obj and "completion" in first_obj:
|
return CompletionsDataset(data, tokenizer)
|
||||||
return CompletionsDataset(path, tokenizer)
|
elif "text" in data[0]:
|
||||||
elif "text" in first_obj:
|
return Dataset(data)
|
||||||
return Dataset(path)
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported data format, check the supported formats here:\n"
|
"Unsupported data format, check the supported formats here:\n"
|
||||||
@ -84,11 +93,53 @@ def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
|
|||||||
|
|
||||||
|
|
||||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||||
names = ("train", "valid", "test")
|
if getattr(args, "hf_dataset", None) is not None:
|
||||||
data_path = Path(args.data)
|
import datasets
|
||||||
train, valid, test = [
|
|
||||||
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
|
hf_args = args.hf_dataset
|
||||||
]
|
dataset_name = hf_args["name"]
|
||||||
|
print(f"Loading Hugging Face dataset {dataset_name}.")
|
||||||
|
text_feature = hf_args.get("text_feature")
|
||||||
|
prompt_feature = hf_args.get("prompt_feature")
|
||||||
|
completion_feature = hf_args.get("completion_feature")
|
||||||
|
|
||||||
|
def create_hf_dataset(split: str = None):
|
||||||
|
ds = datasets.load_dataset(
|
||||||
|
dataset_name,
|
||||||
|
split=split,
|
||||||
|
**hf_args.get("config", {}),
|
||||||
|
)
|
||||||
|
if prompt_feature and completion_feature:
|
||||||
|
return CompletionsDataset(
|
||||||
|
ds, tokenizer, prompt_feature, completion_feature
|
||||||
|
)
|
||||||
|
elif text_feature:
|
||||||
|
return Dataset(train_ds, text_key=text_feature)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Specify either a prompt and completion feature or a text "
|
||||||
|
"feature for the Hugging Face dataset."
|
||||||
|
)
|
||||||
|
|
||||||
|
if args.train:
|
||||||
|
train_split = hf_args.get("train_split", "train[:80%]")
|
||||||
|
valid_split = hf_args.get("valid_split", "train[-10%:]")
|
||||||
|
train = create_hf_dataset(split=train_split)
|
||||||
|
valid = create_hf_dataset(split=valid_split)
|
||||||
|
else:
|
||||||
|
train, valid = [], []
|
||||||
|
if args.test:
|
||||||
|
test = create_hf_dataset(split=hf_args.get("test_split"))
|
||||||
|
else:
|
||||||
|
test = []
|
||||||
|
|
||||||
|
else:
|
||||||
|
names = ("train", "valid", "test")
|
||||||
|
data_path = Path(args.data)
|
||||||
|
|
||||||
|
train, valid, test = [
|
||||||
|
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
|
||||||
|
]
|
||||||
if args.train and len(train) == 0:
|
if args.train and len(train) == 0:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Training set not found or empty. Must provide training set for fine-tuning."
|
"Training set not found or empty. Must provide training set for fine-tuning."
|
||||||
|
@ -1,3 +1,3 @@
|
|||||||
# Copyright © 2023-2024 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
__version__ = "0.15.0"
|
__version__ = "0.16.0"
|
||||||
|
@ -26,6 +26,9 @@ setup(
|
|||||||
install_requires=requirements,
|
install_requires=requirements,
|
||||||
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
|
packages=["mlx_lm", "mlx_lm.models", "mlx_lm.tuner"],
|
||||||
python_requires=">=3.8",
|
python_requires=">=3.8",
|
||||||
|
extras_require={
|
||||||
|
"testing": ["datasets"],
|
||||||
|
},
|
||||||
entry_points={
|
entry_points={
|
||||||
"console_scripts": [
|
"console_scripts": [
|
||||||
"mlx_lm.convert = mlx_lm.convert:main",
|
"mlx_lm.convert = mlx_lm.convert:main",
|
||||||
|
@ -76,6 +76,24 @@ class TestDatasets(unittest.TestCase):
|
|||||||
self.assertTrue(len(valid[0]) > 0)
|
self.assertTrue(len(valid[0]) > 0)
|
||||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||||
|
|
||||||
|
def test_hf(self):
|
||||||
|
args = types.SimpleNamespace(
|
||||||
|
hf_dataset={
|
||||||
|
"name": "billsum",
|
||||||
|
"prompt_feature": "text",
|
||||||
|
"completion_feature": "summary",
|
||||||
|
},
|
||||||
|
test=False,
|
||||||
|
train=True,
|
||||||
|
)
|
||||||
|
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||||
|
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||||
|
self.assertTrue(len(train) > 0)
|
||||||
|
self.assertTrue(len(train[0]) > 0)
|
||||||
|
self.assertTrue(len(valid) > 0)
|
||||||
|
self.assertTrue(len(valid[0]) > 0)
|
||||||
|
self.assertEqual(len(test), 0)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
Loading…
Reference in New Issue
Block a user