initial commit

This commit is contained in:
Goekdeniz-Guelmez
2025-01-19 00:19:36 +01:00
parent 07f88f8057
commit 1ff788821c
3 changed files with 705 additions and 34 deletions

View File

@@ -5,6 +5,51 @@ from typing import Dict, List, Optional
from transformers import PreTrainedTokenizer
class DPODataset:
"""
A dataset for DPO (Direct Preference Optimization) training that handles
prompt-chosen-rejected triplets in the format:
{"prompt": ..., "chosen": ..., "rejected": ...}
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
):
self._chosen_data = [
tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
],
)
for d in data
]
self._rejected_data = [
tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
for d in data
]
def __getitem__(self, idx: int):
return {
"chosen": self._chosen_data[idx],
"rejected": self._rejected_data[idx]
}
def __len__(self):
return len(self._chosen_data)
class Dataset:
"""
Light-weight wrapper to hold a dataset.
@@ -90,7 +135,11 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:
# Add DPO dataset support
if "chosen" in sample and "rejected" in sample:
return DPODataset(data, tokenizer)
elif "messages" in sample:
return ChatDataset(data, tokenizer)
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)