Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-09-01 04:14:38 +08:00
Support for OpenAI’s fine-tuning dataset format (#548)
* LoRA: move load_dataset to tuner/datasets.py file
* LoRA: support OpenAI chat format datasets (see https://platform.openai.com/docs/guides/fine-tuning/example-format)
* LoRA: support OpenAI completion format datasets
* LoRA: formatting dataset timing to reduce memory footprint
* Refactor dataset item access in PromptCompletionDataset
* Update mlx_lm/LORA.md
* Update mlx_lm/LORA.md
* check Unsupported data format
* add tests, fine-tune doc
* add tests, fine-tune doc
* add jinja2 for chat template
* nits in readme
* nits in readme

---------

Co-authored-by: Awni Hannun <awni@apple.com>
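For orientation, the loader distinguishes record types by their top-level JSONL keys. The sketch below is a hypothetical illustration (the example sentences and file names are made up, not from the repository); it writes one record of each supported shape to its own file, since the loader inspects only the first line of a file to pick a dataset class.

import json

# Hypothetical sample records, one per supported format (illustrative only).
chat_record = {  # -> ChatDataset: {"messages": [...]}
    "messages": [
        {"role": "user", "content": "What is MLX?"},
        {"role": "assistant", "content": "An array framework for Apple silicon."},
    ]
}
completion_record = {  # -> CompletionsDataset: {"prompt": ..., "completion": ...}
    "prompt": "What is MLX?",
    "completion": "An array framework for Apple silicon.",
}
text_record = {"text": "MLX is an array framework for Apple silicon."}  # -> Dataset

# create_dataset looks only at the first line of a file, so each .jsonl file
# should stick to a single format. File names here are hypothetical.
for name, record in [
    ("chat.jsonl", chat_record),
    ("completions.jsonl", completion_record),
    ("text.jsonl", text_record),
]:
    with open(name, "w") as fid:
        fid.write(json.dumps(record) + "\n")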
llms/mlx_lm/tuner/datasets.py (new file, 104 lines)
@@ -0,0 +1,104 @@
import json
from pathlib import Path

from transformers import PreTrainedTokenizer


class Dataset:
    """
    Light-weight wrapper to hold lines from a jsonl file
    """

    def __init__(self, path: Path):
        with open(path, "r") as fid:
            self._data = [json.loads(l) for l in fid]

    def __getitem__(self, idx: int):
        return self._data[idx]["text"]

    def __len__(self):
        if self._data is None:
            return 0
        return len(self._data)


class ChatDataset(Dataset):
    """
    A dataset for chat data in the format of {"messages": [...]}
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        messages = self._data[idx]["messages"]
        text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        return text


class CompletionsDataset(Dataset):
    """
    A dataset for prompt-completion data in the format of
    {"prompt": ..., "completion": ...}
    https://platform.openai.com/docs/guides/fine-tuning/example-format
    """

    def __init__(self, path: Path, tokenizer: PreTrainedTokenizer):
        super().__init__(path)
        self._tokenizer = tokenizer

    def __getitem__(self, idx: int):
        data = self._data[idx]
        text = self._tokenizer.apply_chat_template(
            [
                {"role": "user", "content": data["prompt"]},
                {"role": "assistant", "content": data["completion"]},
            ],
            tokenize=False,
            add_generation_prompt=True,
        )
        return text


def create_dataset(path: Path, tokenizer: PreTrainedTokenizer = None):
    # Return empty dataset for non-existent paths
    if not path.exists():
        return []
    with open(path, "r") as fid:
        first_line = next(fid)
        first_obj = json.loads(first_line)
        if "messages" in first_obj:
            return ChatDataset(path, tokenizer)
        elif "prompt" in first_obj and "completion" in first_obj:
            return CompletionsDataset(path, tokenizer)
        elif "text" in first_obj:
            return Dataset(path)
        else:
            raise ValueError(
                "Unsupported data format, check the supported formats here:\n"
                "https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/LORA.md#data."
            )


def load_dataset(args, tokenizer: PreTrainedTokenizer):
    names = ("train", "valid", "test")
    data_path = Path(args.data)
    train, valid, test = [
        create_dataset(data_path / f"{n}.jsonl", tokenizer) for n in names
    ]
    if args.train and len(train) == 0:
        raise ValueError(
            "Training set not found or empty. Must provide training set for fine-tuning."
        )
    if args.train and len(valid) == 0:
        raise ValueError(
            "Validation set not found or empty. Must provide validation set for fine-tuning."
        )
    if args.test and len(test) == 0:
        raise ValueError(
            "Test set not found or empty. Must provide test set for evaluation."
        )
    return train, valid, test