Support for OpenAI’s fine-tuning dataset format (#548)

* LoRA: move load_dataset to tuner/datasets.py file

* LoRA: support OpenAI chat format datasets

see https://platform.openai.com/docs/guides/fine-tuning/example-format

* LoRA: support OpenAI completion format datasets

* LoRA: formatting dataset timing to reduce memory footprint

* Refactor dataset item access in PromptCompletionDataset

* Update mlx_lm/LORA.md

* Update mlx_lm/LORA.md

* check Unsupported data format

* add tests, fine-tune doc

* add tests, fine-tune doc

* add jinja2 for chat template

* nits in readme

* nits in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
madroid 2024-03-20 07:45:46 +08:00 committed by GitHub
parent e05e502c34
commit b0bcd86a40
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 231 additions and 44 deletions

View File

@ -136,14 +136,54 @@ correct format.
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
file should follow one of the formats described below.
Currently, `*.jsonl` files support three data formats: `chat`,
`completions`, and `text`. Here are three examples of these formats:
`chat`:
```jsonl
{"messages": [
{"role": "system", "content": "You are a helpful assistant." },
{"role": "user", "content": "Hello."},
{"role": "assistant", "content": "How can I assistant you today."}
]}
```
`completions`:
```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```
`text`:
```jsonl
{"text": "This is an example for the model."}
```
Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.
For the `chat` and `completions` formats, Hugging Face [chat
templates](https://huggingface.co/blog/chat-templates) are used. This applies
the model's chat template by default. If the model does not have a chat
template, then Hugging Face will use a default. For example, the final text in
the `chat` example above with Hugging Face's default template becomes:
```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assistant you today.<|im_end|>
```
If you are unsure which format to use, `chat` or `completions` is a good place
to start. For custom requirements on the format of the dataset, use the
`text` format to assemble the content yourself.
## Memory Issues

View File

@ -12,6 +12,7 @@ import numpy as np
import yaml import yaml
from mlx.utils import tree_flatten from mlx.utils import tree_flatten
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.utils import linear_to_lora_layers from .tuner.utils import linear_to_lora_layers
from .utils import load from .utils import load
@ -141,46 +142,6 @@ def build_parser():
return parser return parser
class Dataset:
    """
    Thin container for the records of a jsonl file.

    Each line of the file is parsed as JSON when the object is built; a path
    that does not exist yields a dataset of length zero. Indexing returns the
    value stored under `key` in the selected record.
    """

    def __init__(self, path: Path, key: str = "text"):
        self._key = key
        # None (rather than an empty list) marks "the file was not there".
        self._data = None
        if path.exists():
            with open(path, "r") as fid:
                self._data = [json.loads(line) for line in fid]

    def __getitem__(self, idx: int):
        record = self._data[idx]
        return record[self._key]

    def __len__(self):
        return 0 if self._data is None else len(self._data)
def load_dataset(args):
    """
    Load the train, valid and test splits from the directory `args.data`.

    Each split is read from `<name>.jsonl`. A ValueError is raised when
    `args.train` is set and the train or valid split has length zero, or when
    `args.test` is set and the test split has length zero.

    Returns the tuple (train, valid, test).
    """
    data_dir = Path(args.data)
    train = Dataset(data_dir / "train.jsonl")
    valid = Dataset(data_dir / "valid.jsonl")
    test = Dataset(data_dir / "test.jsonl")

    if args.train:
        if len(train) == 0:
            raise ValueError(
                "Training set not found or empty. Must provide training set for fine-tuning."
            )
        if len(valid) == 0:
            raise ValueError(
                "Validation set not found or empty. Must provide validation set for fine-tuning."
            )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test
def print_trainable_parameters(model): def print_trainable_parameters(model):
total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
trainable_p = ( trainable_p = (
@ -206,7 +167,7 @@ def run(args, training_callback: TrainingCallback = None):
print_trainable_parameters(model) print_trainable_parameters(model)
print("Loading datasets") print("Loading datasets")
train_set, valid_set, test_set = load_dataset(args) train_set, valid_set, test_set = load_dataset(args, tokenizer)
# Resume training the given adapters. # Resume training the given adapters.
if args.resume_adapter_file is not None: if args.resume_adapter_file is not None:

View File

@ -3,3 +3,4 @@ numpy
transformers>=4.38.0 transformers>=4.38.0
protobuf protobuf
pyyaml pyyaml
jinja2

View File

@ -0,0 +1,104 @@
import json
from pathlib import Path
from transformers import PreTrainedTokenizer
class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file.

    Every non-blank line of the file is parsed as one JSON object when the
    dataset is constructed. Indexing returns the "text" field of the selected
    record.
    """

    def __init__(self, path: Path):
        # JSON text is UTF-8 by specification, so do not depend on the
        # platform's default encoding.
        with open(path, "r", encoding="utf-8") as fid:
            # Skip whitespace-only lines (for example a trailing blank line
            # at the end of the file) instead of failing inside json.loads.
            self._data = [json.loads(line) for line in fid if line.strip()]

    def __getitem__(self, idx: int):
        return self._data[idx]["text"]

    def __len__(self):
        # _data is always a list once __init__ has returned, so no None
        # check is needed here.
        return len(self._data)
class ChatDataset(Dataset):
    """
    Dataset of chat records shaped like {"messages": [...]}, see
    https://platform.openai.com/docs/guides/fine-tuning/example-format

    Indexing renders the messages of the selected record into a single string
    with the tokenizer's chat template.
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        # The template is applied on each access rather than once for the
        # whole file at load time.
        conversation = self._data[idx]["messages"]
        # NOTE(review): add_generation_prompt=True is usually meant for
        # inference prompts -- confirm it is intended for training text.
        render_options = {"tokenize": False, "add_generation_prompt": True}
        return self._tokenizer.apply_chat_template(conversation, **render_options)
class CompletionsDataset(Dataset):
    """
    Dataset of prompt/completion records shaped like
    {"prompt": ..., "completion": ...}, see
    https://platform.openai.com/docs/guides/fine-tuning/example-format

    Indexing wraps the pair as a user message followed by an assistant
    message and renders it into a single string with the tokenizer's chat
    template.
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        record = self._data[idx]
        render = self._tokenizer.apply_chat_template
        # Present the pair as a one-turn conversation so the chat template
        # can be applied to it.
        conversation = [
            {"role": "user", "content": record["prompt"]},
            {"role": "assistant", "content": record["completion"]},
        ]
        return render(conversation, tokenize=False, add_generation_prompt=True)
def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
    """
    Build the dataset that matches the record format found in a jsonl file.

    The format is detected from the keys of the first non-blank line:
    "messages" selects ChatDataset, "prompt" together with "completion"
    selects CompletionsDataset, and "text" selects Dataset.

    Args:
        path: Location of the jsonl file.
        tokenizer: Tokenizer passed to ChatDataset / CompletionsDataset; it
            is not used for the "text" format.

    Returns:
        A dataset object, or an empty list when the file does not exist or
        contains no records.

    Raises:
        ValueError: If the first record matches none of the supported
            formats.
    """
    # Return empty dataset for non-existent paths
    if not path.exists():
        return []

    with open(path, "r") as fid:
        # A bare next(fid) raises StopIteration on an empty file; use a
        # default so an empty file is reported as an empty dataset instead.
        first_line = next((line for line in fid if line.strip()), None)

    if first_line is None:
        return []

    first_obj = json.loads(first_line)

    # Only a JSON object can carry the format keys. Checking the type keeps
    # `in` from doing a substring test on a string or failing on a number.
    if isinstance(first_obj, dict):
        if "messages" in first_obj:
            return ChatDataset(path, tokenizer)
        if "prompt" in first_obj and "completion" in first_obj:
            return CompletionsDataset(path, tokenizer)
        if "text" in first_obj:
            return Dataset(path)

    raise ValueError(
        "Unsupported data format, check the supported formats here:\n"
        "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
    )
def load_dataset(args, tokenizer: PreTrainedTokenizer):
    """
    Create the train, valid and test datasets from the directory `args.data`.

    Each split is built by create_dataset from `<name>.jsonl`. A ValueError
    is raised when `args.train` is set and the train or valid split has
    length zero, or when `args.test` is set and the test split has length
    zero.

    Returns the tuple (train, valid, test).
    """
    root = Path(args.data)
    train = create_dataset(root / "train.jsonl", tokenizer)
    valid = create_dataset(root / "valid.jsonl", tokenizer)
    test = create_dataset(root / "test.jsonl", tokenizer)

    if args.train:
        if len(train) == 0:
            raise ValueError(
                "Training set not found or empty. Must provide training set for fine-tuning."
            )
        if len(valid) == 0:
            raise ValueError(
                "Validation set not found or empty. Must provide validation set for fine-tuning."
            )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test

View File

@ -0,0 +1,81 @@
# Copyright © 2024 Apple Inc.
import json
import os
import tempfile
import types
import unittest
from mlx_lm.tuner import datasets
from transformers import AutoTokenizer
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
class TestDatasets(unittest.TestCase):
    """Check that load_dataset returns the expected dataset class per format."""

    @classmethod
    def setUpClass(cls):
        # TemporaryDirectory() creates the directory itself, so no explicit
        # mkdir is required.
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def save_data(self, data):
        """Write `data`, one JSON object per line, as train.jsonl and valid.jsonl."""
        for ds in ["train", "valid"]:
            with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid:
                for l in data:
                    json.dump(l, fid)
                    fid.write("\n")

    def _check_load(self, data, tokenizer, expected_cls):
        """Save four copies of `data`, load the splits and verify the result."""
        self.save_data(4 * [data])
        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
        train, valid, test = datasets.load_dataset(args, tokenizer)
        self.assertEqual(len(train), 4)
        self.assertEqual(len(valid), 4)
        # No test.jsonl is written, so the test split must be empty.
        self.assertEqual(len(test), 0)
        self.assertGreater(len(train[0]), 0)
        self.assertGreater(len(valid[0]), 0)
        self.assertIsInstance(train, expected_cls)

    def test_text(self):
        data = {"text": "This is an example for the model."}
        # The "text" format is loaded without a tokenizer.
        self._check_load(data, None, datasets.Dataset)

    def test_completions(self):
        data = {"prompt": "What is the capital of France?", "completion": "Paris."}
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
        self._check_load(data, tokenizer, datasets.CompletionsDataset)

    def test_chat(self):
        data = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."},
                {"role": "assistant", "content": "How can I assistant you today."},
            ]
        }
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
        self._check_load(data, tokenizer, datasets.ChatDataset)
# Run the test suite when this file is executed directly.
if __name__ == "__main__":
    unittest.main()