initial commit
@@ -5,6 +5,51 @@ from typing import Dict, List, Optional
 from transformers import PreTrainedTokenizer
 
 
+class DPODataset:
+    """
+    A dataset for DPO (Direct Preference Optimization) training that handles
+    prompt-chosen-rejected triplets in the format:
+    {"prompt": ..., "chosen": ..., "rejected": ...}
+    """
+
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        chosen_key: str = "chosen",
+        rejected_key: str = "rejected",
+    ):
+        self._chosen_data = [
+            tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[chosen_key]},
+                ],
+            )
+            for d in data
+        ]
+
+        self._rejected_data = [
+            tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[rejected_key]},
+                ],
+            )
+            for d in data
+        ]
+
+    def __getitem__(self, idx: int):
+        return {
+            "chosen": self._chosen_data[idx],
+            "rejected": self._rejected_data[idx]
+        }
+
+    def __len__(self):
+        return len(self._chosen_data)
+
+
 class Dataset:
     """
     Light-weight wrapper to hold a dataset.
@@ -90,7 +135,11 @@ def create_dataset(
     prompt_feature = prompt_feature or "prompt"
     completion_feature = completion_feature or "completion"
     sample = data[0]
-    if "messages" in sample:
+
+    # Add DPO dataset support
+    if "chosen" in sample and "rejected" in sample:
+        return DPODataset(data, tokenizer)
+    elif "messages" in sample:
         return ChatDataset(data, tokenizer)
     elif prompt_feature in sample and completion_feature in sample:
        return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
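Below is a minimal usage sketch of the new class; it is not part of the commit. It assumes DPODataset ends up importable from mlx_lm.tuner.datasets (the module this diff patches) and that the tokenizer ships a chat template; the checkpoint name is only an illustrative choice.

from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import DPODataset  # import path assumed from this diff

# Two preference records in the {"prompt", "chosen", "rejected"} format the class expects.
data = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "I am not sure.",
    },
    {
        "prompt": "Name a prime number.",
        "chosen": "2 is the smallest prime number.",
        "rejected": "Nine.",
    },
]

# Any tokenizer that provides a chat template works; this model is only an example.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

dataset = DPODataset(data, tokenizer)
print(len(dataset))            # 2
print(dataset[0]["chosen"])    # token ids for prompt + chosen reply
print(dataset[0]["rejected"])  # token ids for prompt + rejected reply

In create_dataset, any sample that contains both "chosen" and "rejected" keys is now routed to DPODataset before the existing "messages" and prompt/completion branches, so data files whose records carry those two keys are picked up by the existing loading path.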