mlx-examples/llms/mlx_lm/tuner/datasets.py
madroid b0bcd86a40
Support for OpenAI’s fine-tuning dataset format (#548)
* LoRA: move load_dataset to tuner/datasets.py file

* LoRA: support OpenAI chat format datasets

see https://platform.openai.com/docs/guides/fine-tuning/example-format

* LoRA: support OpenAI completion format datasets

* LoRA: defer dataset formatting to item access to reduce memory footprint

* Refactor dataset item access in PromptCompletionDataset

* Update mlx_lm/LORA.md

* Update mlx_lm/LORA.md

* check Unsupported data format

* add tests, fine-tune doc

* add tests, fine-tune doc

* add jinja2 for chat template

* nits in readme

* nits in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
2024-03-19 16:45:46 -07:00

105 lines
3.2 KiB
Python

import json
from pathlib import Path
from transformers import PreTrainedTokenizer
class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file.

    Every non-blank line of the file is parsed as one JSON object when the
    dataset is constructed. Indexing returns the ``"text"`` field of the
    corresponding record.
    """

    def __init__(self, path: Path):
        """
        Load and parse the jsonl file at ``path``.

        Args:
            path: Location of the jsonl file to read.

        Raises:
            OSError: If the file cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        # JSON text is UTF-8; decode explicitly instead of relying on the
        # platform's default encoding.
        with open(path, "r", encoding="utf-8") as fid:
            # Blank / whitespace-only lines (e.g. an extra empty line at the
            # end of the file) are not records and would make json.loads fail.
            self._data = [json.loads(line) for line in fid if line.strip()]

    def __getitem__(self, idx: int):
        """Return the ``"text"`` field of the record at position ``idx``."""
        return self._data[idx]["text"]

    def __len__(self):
        """Return the number of records held (0 when ``_data`` is ``None``)."""
        if self._data is None:
            return 0
        return len(self._data)
class ChatDataset(Dataset):
    """
    A dataset for chat data in the format of {"messages": [...]}
    https://platform.openai.com/docs/guides/fine-tuning/example-format

    Records are kept as parsed JSON; the tokenizer's chat template is applied
    each time an item is accessed rather than up front.
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        """
        Args:
            path: Location of the jsonl file to read.
            tokenizer: Tokenizer whose ``apply_chat_template`` is used to
                render each conversation to text.
        """
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        """Return the conversation at ``idx`` rendered with the chat template."""
        messages = self._data[idx]["messages"]
        # The samples are complete conversations used as fine-tuning text, so
        # no generation prompt is appended: a trailing "assistant" header with
        # no content is only meaningful when prompting for a new reply.
        return self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
class CompletionsDataset(Dataset):
    """
    A dataset for prompt-completion data in the format of {"prompt": ..., "completion": ...}
    https://platform.openai.com/docs/guides/fine-tuning/example-format

    Each record is turned into a two-message conversation (user prompt,
    assistant completion) and rendered with the tokenizer's chat template
    when the item is accessed.
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        """
        Args:
            path: Location of the jsonl file to read.
            tokenizer: Tokenizer whose ``apply_chat_template`` is used to
                render each prompt/completion pair to text.
        """
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        """Return the prompt/completion pair at ``idx`` rendered as chat text."""
        data = self._data[idx]
        messages = [
            {"role": "user", "content": data["prompt"]},
            {"role": "assistant", "content": data["completion"]},
        ]
        # The completion is already present as the assistant turn, so no
        # generation prompt is appended after it: a trailing empty assistant
        # header is only useful when asking the model for a new reply.
        return self._tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=False,
        )
def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
    """
    Build the dataset wrapper matching the record format found in ``path``.

    The format is detected from the first non-blank line of the file:
    ``"messages"`` selects ``ChatDataset``, ``"prompt"`` together with
    ``"completion"`` selects ``CompletionsDataset`` and ``"text"`` selects
    ``Dataset``.

    Args:
        path: Location of the jsonl file.
        tokenizer: Tokenizer handed to the chat / completions datasets.

    Returns:
        A dataset object, or an empty list when the file does not exist or
        contains no records.

    Raises:
        ValueError: If the first record is not in a supported format.
        json.JSONDecodeError: If the first non-blank line is not valid JSON.
    """
    # Return empty dataset for non-existent paths
    if not path.exists():
        return []

    with open(path, "r", encoding="utf-8") as fid:
        # A bare next(fid) raises StopIteration on an empty file; look for the
        # first non-blank line and fall back to None instead.
        first_line = next((line for line in fid if line.strip()), None)

    # A file without records is treated like a missing one, so callers can
    # detect it via len(...) == 0.
    if first_line is None:
        return []

    first_obj = json.loads(first_line)
    # Only JSON objects can carry the keys checked below.
    if isinstance(first_obj, dict):
        if "messages" in first_obj:
            return ChatDataset(path, tokenizer)
        if "prompt" in first_obj and "completion" in first_obj:
            return CompletionsDataset(path, tokenizer)
        if "text" in first_obj:
            return Dataset(path)

    raise ValueError(
        "Unsupported data format, check the supported formats here:\n"
        "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
    )
def load_dataset(args, tokenizer: PreTrainedTokenizer):
    """
    Load the train, validation and test splits from the ``args.data`` directory.

    Each split is read from ``<args.data>/<split>.jsonl`` via
    ``create_dataset``.

    Args:
        args: Object providing ``data`` (directory path), ``train`` and
            ``test`` (flags selecting which splits are required).
        tokenizer: Tokenizer forwarded to ``create_dataset``.

    Returns:
        A ``(train, valid, test)`` tuple of datasets.

    Raises:
        ValueError: If a split required by ``args.train`` / ``args.test`` has
            length zero.
    """
    root = Path(args.data)
    splits = {
        split: create_dataset(root / f"{split}.jsonl", tokenizer)
        for split in ("train", "valid", "test")
    }
    train = splits["train"]
    valid = splits["valid"]
    test = splits["test"]

    # Fine-tuning needs both a training and a validation split.
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )

    # Evaluation needs a test split.
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )

    return train, valid, test