Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-24 17:31:18 +08:00)
Support for OpenAI’s fine-tuning dataset format (#548)
* LoRA: move load_dataset to tuner/datasets.py file
* LoRA: support OpenAI chat format datasets (see
  https://platform.openai.com/docs/guides/fine-tuning/example-format)
* LoRA: support OpenAI completion format datasets
* LoRA: format the dataset at training time to reduce the memory footprint
* Refactor dataset item access in PromptCompletionDataset
* Update mlx_lm/LORA.md
* Check unsupported data format
* Add tests and fine-tune docs
* Add jinja2 for chat templates
* Nits in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
This commit is contained in:
parent e05e502c34
commit b0bcd86a40

llms/mlx_lm/LORA.md
@@ -136,14 +136,54 @@ correct format.

For fine-tuning (`--train`), the data loader expects a `train.jsonl` and a
`valid.jsonl` to be in the data directory. For evaluation (`--test`), the data
loader expects a `test.jsonl` in the data directory.

Currently, `*.jsonl` files support three data formats: `chat`,
`completions`, and `text`. Here are three examples of these formats:

`chat`:

```jsonl
{"messages": [
  {"role": "system", "content": "You are a helpful assistant."},
  {"role": "user", "content": "Hello."},
  {"role": "assistant", "content": "How can I assist you today."}
]}
```

`completions`:

```jsonl
{"prompt": "What is the capital of France?", "completion": "Paris."}
```

`text`:

```jsonl
{"text": "This is an example for the model."}
```

Note, the format is automatically determined by the dataset. Note also, keys in
each line not expected by the loader will be ignored.

For the `chat` and `completions` formats, Hugging Face [chat
templates](https://huggingface.co/blog/chat-templates) are used. This applies
the model's chat template by default. If the model does not have a chat
template, then Hugging Face will use a default. For example, the final text in
the `chat` example above with Hugging Face's default template becomes:

```text
<|im_start|>system
You are a helpful assistant.<|im_end|>
<|im_start|>user
Hello.<|im_end|>
<|im_start|>assistant
How can I assist you today.<|im_end|>
```
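
You can preview this rendering yourself. The snippet below is a small sketch, not
part of the commit: it applies a chat template to the `chat` example above, using
the model path from the new tests (any chat-capable tokenizer should behave
similarly).

```python
# Sketch: render a chat-format example through a Hugging Face chat template.
from transformers import AutoTokenizer

# Model path taken from the new tests; substitute your own fine-tuning model.
tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello."},
    {"role": "assistant", "content": "How can I assist you today."},
]

# This mirrors what the new ChatDataset does for each line of the jsonl file.
print(
    tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
)
```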

If you are unsure of which format to use, the `chat` or `completions` formats are a
good place to start. For custom requirements on the format of the dataset, use the
`text` format to assemble the content yourself, as in the sketch below.
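
As an illustration only (not part of the commit), here is a minimal sketch of
assembling a `text`-format dataset yourself; the example strings and the `data/`
directory are placeholders.

```python
# Sketch: write pre-formatted training text as a `text`-format jsonl dataset.
import json
from pathlib import Path

# Placeholder examples; put whatever fully formatted training strings you need here.
examples = [
    {"text": "This is an example for the model."},
    {"text": "Another fully formatted training example."},
]

data_dir = Path("data")  # placeholder directory, passed later via --data
data_dir.mkdir(exist_ok=True)
for split in ("train", "valid"):
    with open(data_dir / f"{split}.jsonl", "w") as fid:
        for ex in examples:
            json.dump(ex, fid)
            fid.write("\n")
```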

## Memory Issues

llms/mlx_lm/lora.py
@@ -12,6 +12,7 @@ import numpy as np
 import yaml
 from mlx.utils import tree_flatten

+from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
 from .tuner.utils import linear_to_lora_layers
 from .utils import load

@@ -141,46 +142,6 @@ def build_parser():
     return parser


-class Dataset:
-    """
-    Light-weight wrapper to hold lines from a jsonl file
-    """
-
-    def __init__(self, path: Path, key: str = "text"):
-        if not path.exists():
-            self._data = None
-        else:
-            with open(path, "r") as fid:
-                self._data = [json.loads(l) for l in fid]
-        self._key = key
-
-    def __getitem__(self, idx: int):
-        return self._data[idx][self._key]
-
-    def __len__(self):
-        if self._data is None:
-            return 0
-        return len(self._data)
-
-
-def load_dataset(args):
-    names = ("train", "valid", "test")
-    train, valid, test = (Dataset(Path(args.data) / f"{n}.jsonl") for n in names)
-    if args.train and len(train) == 0:
-        raise ValueError(
-            "Training set not found or empty. Must provide training set for fine-tuning."
-        )
-    if args.train and len(valid) == 0:
-        raise ValueError(
-            "Validation set not found or empty. Must provide validation set for fine-tuning."
-        )
-    if args.test and len(test) == 0:
-        raise ValueError(
-            "Test set not found or empty. Must provide test set for evaluation."
-        )
-    return train, valid, test
-
-
 def print_trainable_parameters(model):
     total_p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6
     trainable_p = (

@@ -206,7 +167,7 @@ def run(args, training_callback: TrainingCallback = None):
     print_trainable_parameters(model)

     print("Loading datasets")
-    train_set, valid_set, test_set = load_dataset(args)
+    train_set, valid_set, test_set = load_dataset(args, tokenizer)

     # Resume training the given adapters.
     if args.resume_adapter_file is not None:

llms/mlx_lm/requirements.txt
@@ -3,3 +3,4 @@ numpy
 transformers>=4.38.0
 protobuf
 pyyaml
+jinja2

llms/mlx_lm/tuner/datasets.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import json
from pathlib import Path

from transformers import PreTrainedTokenizer


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path):
        with open(path, "r") as fid:
            self._data = [json.loads(l) for l in fid]

    def __getitem__(self, idx: int):
        return self._data[idx]["text"]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)


class ChatDataset(Dataset):
    """
    A dataset for chat data in the format of {"messages": [...]}
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        messages = self._data[idx]["messages"]
        text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return text


class CompletionsDataset(Dataset):
    """
    A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        data = self._data[idx]
        text = self._tokenizer.apply_chat_template(
            [
                {"role": "user", "content": data["prompt"]},
                {"role": "assistant", "content": data["completion"]},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        return text


def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
    # Return empty dataset for non-existent paths
    if not path.exists():
        return []
    with open(path, "r") as fid:
        first_line = next(fid)
        first_obj = json.loads(first_line)
    if "messages" in first_obj:
        return ChatDataset(path, tokenizer)
    elif "prompt" in first_obj and "completion" in first_obj:
        return CompletionsDataset(path, tokenizer)
    elif "text" in first_obj:
        return Dataset(path)
    else:
        raise ValueError(
            "Unsupported data format, check the supported formats here:\n"
            "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
        )


def load_dataset(args, tokenizer: PreTrainedTokenizer):
    names = ("train", "valid", "test")
    data_path = Path(args.data)
    train, valid, test = [
        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
    ]
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test
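
As a usage sketch (not part of the commit, but mirroring how the new tests below
exercise the loader): the `data` directory here is a placeholder, and the model
path is the one used in the tests.

```python
# Sketch: load a jsonl dataset directory with the new loader.
import types

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import load_dataset

tokenizer = AutoTokenizer.from_pretrained("mlx-community/Qwen1.5-0.5B-Chat-4bit")
args = types.SimpleNamespace(train=True, test=False, data="data")  # placeholder dir

# The format (chat, completions, or text) is detected from the first jsonl line.
train_set, valid_set, test_set = load_dataset(args, tokenizer)
print(len(train_set), train_set[0][:80])
```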

llms/tests/test_datsets.py (new file, 81 lines)
@@ -0,0 +1,81 @@
# Copyright © 2024 Apple Inc.

import json
import os
import tempfile
import types
import unittest

from mlx_lm.tuner import datasets
from transformers import AutoTokenizer

HF_MODEL_PATH = "mlx-community/Qwen1.5-0.5B-Chat-4bit"


class TestDatasets(unittest.TestCase):

    @classmethod
    def setUpClass(cls):
        cls.test_dir_fid = tempfile.TemporaryDirectory()
        cls.test_dir = cls.test_dir_fid.name
        if not os.path.isdir(cls.test_dir):
            os.mkdir(cls.test_dir_fid.name)

    @classmethod
    def tearDownClass(cls):
        cls.test_dir_fid.cleanup()

    def save_data(self, data):
        for ds in ["train", "valid"]:
            with open(os.path.join(self.test_dir, f"{ds}.jsonl"), "w") as fid:
                for l in data:
                    json.dump(l, fid)
                    fid.write("\n")

    def test_text(self):
        data = {"text": "This is an example for the model."}
        self.save_data(4 * [data])
        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
        train, valid, test = datasets.load_dataset(args, None)
        self.assertEqual(len(train), 4)
        self.assertEqual(len(valid), 4)
        self.assertEqual(len(test), 0)
        self.assertTrue(len(train[0]) > 0)
        self.assertTrue(len(valid[0]) > 0)
        self.assertTrue(isinstance(train, datasets.Dataset))

    def test_completions(self):
        data = {"prompt": "What is the capital of France?", "completion": "Paris."}
        self.save_data(4 * [data])
        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
        train, valid, test = datasets.load_dataset(args, tokenizer)
        self.assertEqual(len(train), 4)
        self.assertEqual(len(valid), 4)
        self.assertEqual(len(test), 0)
        self.assertTrue(len(train[0]) > 0)
        self.assertTrue(len(valid[0]) > 0)
        self.assertTrue(isinstance(train, datasets.CompletionsDataset))

    def test_chat(self):
        data = {
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello."},
                {"role": "assistant", "content": "How can I assist you today."},
            ]
        }
        self.save_data(4 * [data])
        args = types.SimpleNamespace(train=True, test=False, data=self.test_dir)
        tokenizer = AutoTokenizer.from_pretrained(HF_MODEL_PATH)
        train, valid, test = datasets.load_dataset(args, tokenizer)
        self.assertEqual(len(train), 4)
        self.assertEqual(len(valid), 4)
        self.assertEqual(len(test), 0)
        self.assertTrue(len(train[0]) > 0)
        self.assertTrue(len(valid[0]) > 0)
        self.assertTrue(isinstance(train, datasets.ChatDataset))


if __name__ == "__main__":
    unittest.main()