mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Support for OpenAI’s fine-tuning dataset format (#548)
* LoRA: move load_dataset to tuner/datasets.py file * LoRA: support OpenAI chat format datasets see https://platform.openai.com/docs/guides/fine-tuning/example-format * LoRA: support OpenAI completion format datasets * LoRA: formatting dataset timing to reduce memory footprint * Refactor dataset item access in PromptCompletionDataset * Update mlx_lm/LORA.md * Update mlx_lm/LORA.md * check Unsupported data format * add tests, fine-tune doc * add tests, fine-tune doc * add jinja2 for chat template * nits in readme * nits in readme --------- Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent
e05e502c34
commit
b0bcd86a40
@ -136,14 +136,54 @@ correct format.
|
||||
|
||||
For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
|
||||
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
|
||||
loader expects a `test.jsonl` in the data directory. Each line in the `*.jsonl`
|
||||
file should look like:
|
||||
loader expects a `test.jsonl` in the data directory.
|
||||
|
||||
Currently, `*.jsonl` files support three data formats: `chat`,
|
||||
`completions`, and `text`. Here are three examples of these formats:
|
||||
|
||||
`chat`:
|
||||
|
||||
```jsonl
|
||||
{"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant." },
|
||||
{"role": "user", "content": "Hello."},
|
||||
{"role": "assistant", "content": "How can I assistant you today."},
|
||||
]}
|
||||
```
|
||||
|
||||
`completions`:
|
||||
|
||||
```jsonl
|
||||
{"prompt": "What is the capital of France?", "completion": "Paris."}
|
||||
```
|
||||
|
||||
`text`:
|
||||
|
||||
```jsonl
|
||||
{"text": "This is an example for the model."}
|
||||
```
|
||||
|
||||
Note, other keys will be ignored by the loader.
|
||||
Note, the format is automatically determined by the dataset. Note also, keys in
|
||||
each line not expected by the loader will be ignored.
|
||||
|
||||
For the `chat` and `completions` formats, Hugging Face [chat
|
||||
templates](https://huggingface.co/blog/chat-templates) are used. This applies
|
||||
the model's chat template by default. If the model does not have a chat
|
||||
template, then Hugging Face will use a default. For example, the final text in
|
||||
the `chat` example above with Hugging Face's default template becomes:
|
||||
|
||||
```text
|
||||
<|im_start|>system
|
||||
You are a helpful assistant.<|im_end|>
|
||||
<|im_start|>user
|
||||
Hello.<|im_end|>
|
||||
<|im_start|>assistant
|
||||
How can I assistant you today.<|im_end|>
|
||||
```
|
||||
|
||||
If you are unsure of the format to use, the `chat` or `completions` are good to
|
||||
start with. For custom requirements on the format of the dataset, use the
|
||||
`text` format to assemble the content yourself.
|
||||
|
||||
## Memory Issues
|
||||
|
||||
|
@ -12,6 +12,7 @@ import numpy as np
|
||||
import yaml
|
||||
from mlx.utils import tree_flatten
|
||||
|
||||
from .tuner.datasets import load_dataset
|
||||
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
|
||||
from .tuner.utils import linear_to_lora_layers
|
||||
from .utils import load
|
||||
@ -141,46 +142,6 @@ def build_parser():
|
||||
return parser
|
||||
|
||||
|
||||
class Dataset:
|
||||
"""
|
||||
Light-weight wrapper to hold lines from a jsonl file
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path, key: str = "text"):
|
||||
if not path.exists():
|
||||
self._data = None
|
||||
else:
|
||||
with open(path, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
self._key = key
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx][self._key]
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
def load_dataset(args):
|
||||
names = ("train", "valid", "test")
|
||||
train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
"Training set not found or empty. Must provide training set for fine-tuning."
|
||||
)
|
||||
if args.train and len(valid) == 0:
|
||||
raise ValueError(
|
||||
"Validation set not found or empty. Must provide validation set for fine-tuning."
|
||||
)
|
||||
if args.test and len(test) == 0:
|
||||
raise ValueError(
|
||||
"Test set not found or empty. Must provide test set for evaluation."
|
||||
)
|
||||
return train, valid, test
|
||||
|
||||
|
||||
def print_trainable_parameters(model):
|
||||
total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
|
||||
trainable_p = (
|
||||
@ -206,7 +167,7 @@ def run(args, training_callback: TrainingCallback = None):
|
||||
print_trainable_parameters(model)
|
||||
|
||||
print("Loading datasets")
|
||||
train_set, valid_set, test_set = load_dataset(args)
|
||||
train_set, valid_set, test_set = load_dataset(args, tokenizer)
|
||||
|
||||
# Resume training the given adapters.
|
||||
if args.resume_adapter_file is not None:
|
||||
|
@ -3,3 +3,4 @@ numpy
|
||||
transformers>=4.38.0
|
||||
protobuf
|
||||
pyyaml
|
||||
jinja2
|
||||
|
104
llms/mlx_lm/tuner/datasets.py
Normal file
104
llms/mlx_lm/tuner/datasets.py
Normal file
@ -0,0 +1,104 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
class Dataset:
|
||||
"""
|
||||
Light-weight wrapper to hold lines from a jsonl file
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path):
|
||||
with open(path, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return self._data[idx]["text"]
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
return len(self._data)
|
||||
|
||||
|
||||
class ChatDataset(Dataset):
|
||||
"""
|
||||
A dataset for chat data in the format of {"messages": [...]}
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
|
||||
super().__init__(path)
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
messages = self._data[idx]["messages"]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
messages, tokenize=False, add_generation_prompt=True
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
class CompletionsDataset(Dataset):
|
||||
"""
|
||||
A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
|
||||
https://platform.openai.com/docs/guides/fine-tuning/example-format
|
||||
"""
|
||||
|
||||
def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
|
||||
super().__init__(path)
|
||||
self._tokenizer = tokenizer
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
data = self._data[idx]
|
||||
text = self._tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": data["prompt"]},
|
||||
{"role": "assistant", "content": data["completion"]},
|
||||
],
|
||||
tokenize=False,
|
||||
add_generation_prompt=True,
|
||||
)
|
||||
return text
|
||||
|
||||
|
||||
def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
|
||||
# Return empty dataset for non-existent paths
|
||||
if not path.exists():
|
||||
return []
|
||||
with open(path, "r") as fid:
|
||||
first_line = next(fid)
|
||||
first_obj = json.loads(first_line)
|
||||
if "messages" in first_obj:
|
||||
return ChatDataset(path, tokenizer)
|
||||
elif "prompt" in first_obj and "completion" in first_obj:
|
||||
return CompletionsDataset(path, tokenizer)
|
||||
elif "text" in first_obj:
|
||||
return Dataset(path)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported data format, check the supported formats here:\n"
|
||||
"https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
|
||||
)
|
||||
|
||||
|
||||
def load_dataset(args, tokenizer: PreTrainedTokenizer):
|
||||
names = ("train", "valid", "test")
|
||||
data_path = Path(args.data)
|
||||
train, valid, test = [
|
||||
create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
|
||||
]
|
||||
if args.train and len(train) == 0:
|
||||
raise ValueError(
|
||||
"Training set not found or empty. Must provide training set for fine-tuning."
|
||||
)
|
||||
if args.train and len(valid) == 0:
|
||||
raise ValueError(
|
||||
"Validation set not found or empty. Must provide validation set for fine-tuning."
|
||||
)
|
||||
if args.test and len(test) == 0:
|
||||
raise ValueError(
|
||||
"Test set not found or empty. Must provide test set for evaluation."
|
||||
)
|
||||
return train, valid, test
|
81
llms/tests/test_datsets.py
Normal file
81
llms/tests/test_datsets.py
Normal file
@ -0,0 +1,81 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import types
|
||||
import unittest
|
||||
|
||||
from mlx_lm.tuner import datasets
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"
|
||||
|
||||
|
||||
class TestDatasets(unittest.TestCase):
|
||||
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.test_dir_fid = tempfile.TemporaryDirectory()
|
||||
cls.test_dir = cls.test_dir_fid.name
|
||||
if not os.path.isdir(cls.test_dir):
|
||||
os.mkdir(cls.test_dir_fid.name)
|
||||
|
||||
@classmethod
|
||||
def tearDownClass(cls):
|
||||
cls.test_dir_fid.cleanup()
|
||||
|
||||
def save_data(self, data):
|
||||
for ds in ["train", "valid"]:
|
||||
with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid:
|
||||
for l in data:
|
||||
json.dump(l, fid)
|
||||
fid.write("\n")
|
||||
|
||||
def test_text(self):
|
||||
data = {"text": "This is an example for the model."}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
train, valid, test = datasets.load_dataset(args, None)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.Dataset))
|
||||
|
||||
def test_completions(self):
|
||||
data = {"prompt": "What is the capital of France?", "completion": "Paris."}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.CompletionsDataset))
|
||||
|
||||
def test_chat(self):
|
||||
data = {
|
||||
"messages": [
|
||||
{"role": "system", "content": "You are a helpful assistant."},
|
||||
{"role": "user", "content": "Hello."},
|
||||
{"role": "assistant", "content": "How can I assistant you today."},
|
||||
]
|
||||
}
|
||||
self.save_data(4 * [data])
|
||||
args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
|
||||
train, valid, test = datasets.load_dataset(args, tokenizer)
|
||||
self.assertEqual(len(train), 4)
|
||||
self.assertEqual(len(valid), 4)
|
||||
self.assertEqual(len(test), 0)
|
||||
self.assertTrue(len(train[0]) > 0)
|
||||
self.assertTrue(len(valid[0]) > 0)
|
||||
self.assertTrue(isinstance(train, datasets.ChatDataset))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
Loading…
Reference in New Issue
Block a user