Remove DPO and fix several issues in ORPO

commit e3688293ed
parent 0bb001121e
Author: Goekdeniz-Guelmez
Date:   2025-01-24 16:09:22 +01:00
4 changed files with 153 additions and 714 deletions


@@ -4,70 +4,47 @@ from typing import Dict, List, Optional
 from transformers import PreTrainedTokenizer


-class DPODataset:
-    """
-    A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets with optional scores in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
-    """
-
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        tokenizer: PreTrainedTokenizer,
-        prompt_key: str = "prompt",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected",
-        score_chosen_key: str = "score_chosen",
-        score_rejected_key: str = "score_rejected",
-    ):
+class ORPODataset:
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        chosen_key: str = "chosen",
+        rejected_key: str = "rejected",
+        preference_score_key: str = "preference_score"
+    ):
         self._chosen_data = []
         self._rejected_data = []
         self._scores = []
         for d in data:
-            # Process the text data
-            chosen_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            rejected_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
+            chosen_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[chosen_key]},
+            ])
+            rejected_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[rejected_key]},
+            ])
             self._chosen_data.append(chosen_text)
             self._rejected_data.append(rejected_text)
-            # Handle scores if they exist
-            if score_chosen_key in d and score_rejected_key in d:
-                chosen_score = float(d[score_chosen_key])
-                rejected_score = float(d[score_rejected_key])
-                # Normalize scores to [0, 1] range
-                score_diff = chosen_score - rejected_score
-                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
-                normalized_score = (score_diff / max_diff + 1) / 2
-                self._scores.append(normalized_score)
+            if preference_score_key in d:
+                self._scores.append(float(d[preference_score_key]))
             else:
                 # Default to binary preference (1.0) if no scores provided
                 self._scores.append(1.0)

-    def __getitem__(self, idx: int):
-        return {
-            "chosen": self._chosen_data[idx],
-            "rejected": self._rejected_data[idx],
-            "preference_score": self._scores[idx]
-        }
-
-    def __len__(self):
-        return len(self._chosen_data)
+    def __getitem__(self, idx: int):
+        return {
+            "chosen": self._chosen_data[idx],
+            "rejected": self._rejected_data[idx],
+            "preference_score": self._scores[idx]
+        }
+
+    def __len__(self):
+        return len(self._chosen_data)


 class Dataset:
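For review context, a minimal standalone sketch (function names are illustrative, not from the repository) contrasting the score handling this commit removes with the one it adds: the old DPODataset mapped a (score_chosen, score_rejected) pair into [0, 1], while ORPODataset now reads a single precomputed preference_score and falls back to a hard preference of 1.0.

```python
def dpo_normalized_score(score_chosen: float, score_rejected: float) -> float:
    # Removed logic: map a raw score pair into [0, 1].
    score_diff = score_chosen - score_rejected
    max_diff = max(abs(score_diff), 1.0)  # avoid division by zero
    return (score_diff / max_diff + 1) / 2


def orpo_preference_score(d: dict, key: str = "preference_score") -> float:
    # New logic: trust the dataset's own score; default to hard preference.
    return float(d[key]) if key in d else 1.0


print(dpo_normalized_score(8.0, 3.0))  # 1.0  -> any gap >= 1 saturates
print(dpo_normalized_score(3.0, 8.0))  # 0.0
print(dpo_normalized_score(4.0, 3.5))  # 0.75 -> mild preference preserved
print(orpo_preference_score({"preference_score": 0.9}))  # 0.9
print(orpo_preference_score({}))                         # 1.0
```

Note what the removal changes: the old normalization collapsed any score gap of 1 or more to exactly 0.0 or 1.0, so it only preserved gradations for gaps smaller than 1; the new scheme passes the dataset's score through untouched.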
@@ -158,7 +135,7 @@ def create_dataset(
     # Add DPO dataset support
     if "chosen" in sample and "rejected" in sample:
-        return DPODataset(data, tokenizer)
+        return ORPODataset(data, tokenizer)
     elif "messages" in sample:
         return ChatDataset(data, tokenizer)
     elif prompt_feature in sample and completion_feature in sample:
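As a usage aid for reviewers, a hypothetical end-to-end sketch of the new path, assuming ORPODataset from this file is in scope; the records follow the keys the class reads, and the model name is an arbitrary example with a chat template, not something this repository pins. Since create_dataset still matches "chosen"/"rejected" first, any preference-pair data now resolves to ORPODataset.

```python
from transformers import AutoTokenizer

data = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "It might be Lyon.",
        "preference_score": 0.9,
    },
    {
        # No score key: ORPODataset falls back to 1.0 (hard preference).
        "prompt": "Name a prime number.",
        "chosen": "2 is the smallest prime number.",
        "rejected": "9 is a prime number.",
    },
]

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")
dataset = ORPODataset(data, tokenizer)

sample = dataset[0]
# apply_chat_template tokenizes by default, so "chosen"/"rejected"
# hold token-id sequences for the full prompt+response conversations.
print(len(dataset), sample["preference_score"])  # 2 0.9
```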