This commit is contained in:
Goekdeniz-Guelmez
2025-01-19 13:45:33 +01:00
parent ea0d11cd2f
commit 363bde634e
5 changed files with 55 additions and 36 deletions

View File

@@ -8,8 +8,8 @@ from transformers import PreTrainedTokenizer
class DPODataset:
"""
A dataset for DPO (Direct Preference Optimization) training that handles
prompt-chosen-rejected triplets in the format:
{"prompt": ..., "chosen": ..., "rejected": ...}
prompt-chosen-rejected triplets with optional scores in the format:
{"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
"""
def __init__(
@@ -19,31 +19,51 @@ class DPODataset:
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
score_chosen_key: str = "score_chosen",
score_rejected_key: str = "score_rejected",
):
self._chosen_data = [
tokenizer.apply_chat_template(
self._chosen_data = []
self._rejected_data = []
self._scores = []
for d in data:
# Process the text data
chosen_text = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
],
)
for d in data
]
self._rejected_data = [
tokenizer.apply_chat_template(
rejected_text = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
for d in data
]
self._chosen_data.append(chosen_text)
self._rejected_data.append(rejected_text)
# Handle scores if they exist
if score_chosen_key in d and score_rejected_key in d:
chosen_score = float(d[score_chosen_key])
rejected_score = float(d[score_rejected_key])
# Normalize scores to [0, 1] range
score_diff = chosen_score - rejected_score
max_diff = max(abs(score_diff), 1.0) # Avoid division by zero
normalized_score = (score_diff / max_diff + 1) / 2
self._scores.append(normalized_score)
else:
# Default to binary preference (1.0) if no scores provided
self._scores.append(1.0)
def __getitem__(self, idx: int):
    """Return the preference example stored at position ``idx``.

    The three values are read from the parallel per-example lists kept on
    the instance, so the same ``idx`` addresses one example in each.

    Args:
        idx: Position of the example to fetch.

    Returns:
        A dict with the keys ``"chosen"`` (from ``self._chosen_data``),
        ``"rejected"`` (from ``self._rejected_data``) and
        ``"preference_score"`` (from ``self._scores``).
    """
    return {
        "chosen": self._chosen_data[idx],
        # Exactly one "rejected" entry: the span previously carried a stale
        # duplicate of this line without a trailing comma, which made the
        # dict literal unparseable.
        "rejected": self._rejected_data[idx],
        "preference_score": self._scores[idx],
    }
def __len__(self):