This commit is contained in:
Goekdeniz-Guelmez 2025-03-01 12:47:13 +01:00
parent 8aeea10901
commit c119a7a4a5

View File

@ -2,48 +2,103 @@ import itertools
import json import json
import types import types
from pathlib import Path from pathlib import Path
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Union
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
class DPODataset: class DPODataset:
""" def __init__(
A dataset for DPO (Direct Preference Optimization) training that handles self,
prompt-chosen-rejected triplets in the format: data: List[Dict[str, Union[str, Dict, List]]],
{"system": ..., "prompt": ..., "chosen": ..., "rejected": ...} tokenizer: PreTrainedTokenizer,
""" prompt_key: str = "prompt",
chosen_key: str = "chosen",
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, rejected_key: str = "rejected",
prompt_key: str = "prompt", chosen_key: str = "chosen", system_key: str = None
rejected_key: str = "rejected", system_key: str = "system"): ):
self._chosen_data = [] self._chosen_data = []
self._rejected_data = [] self._rejected_data = []
for d in data: for d in data:
messages = ( # Get prompt content, preferring 'prompt' over 'question'
[{"role": "system", "content": d[system_key]}] if system_key and system_key in d else [] prompt_content = d.get(prompt_key, d.get("question", ""))
)
messages.append({"role": "user", "content": d[prompt_key]})
# Apply template once for each response type if system_key and system_key in d:
base_messages = messages.copy() base_messages = [{"role": "system", "content": d[system_key]}]
chosen_messages = base_messages + [{"role": "assistant", "content": d[chosen_key]}] chosen_messages = base_messages + [{"role": "user", "content": prompt_content}]
rejected_messages = base_messages + [{"role": "assistant", "content": d[rejected_key]}] rejected_messages = base_messages + [{"role": "user", "content": prompt_content}]
# Handle chosen messages
if isinstance(d[chosen_key], str):
chosen_messages.append({"role": "assistant", "content": d[chosen_key]})
elif isinstance(d[chosen_key], dict):
if "messages" in d[chosen_key]:
chosen_messages.extend(d[chosen_key]["messages"])
else:
chosen_messages.append({"role": "assistant", "content": d[chosen_key].get("content", "")})
elif isinstance(d[chosen_key], list):
chosen_messages.extend(d[chosen_key])
# Handle rejected messages
if isinstance(d[rejected_key], str):
rejected_messages.append({"role": "assistant", "content": d[rejected_key]})
elif isinstance(d[rejected_key], dict):
if "messages" in d[rejected_key]:
rejected_messages.extend(d[rejected_key]["messages"])
else:
rejected_messages.append({"role": "assistant", "content": d[rejected_key].get("content", "")})
elif isinstance(d[rejected_key], list):
rejected_messages.extend(d[rejected_key])
chosen_text = tokenizer.apply_chat_template(chosen_messages)
rejected_text = tokenizer.apply_chat_template(rejected_messages)
self._chosen_data.append(tokenizer.apply_chat_template(chosen_messages)) else:
self._rejected_data.append(tokenizer.apply_chat_template(rejected_messages)) # Handle non-system message cases
chosen_content = self._extract_content(d[chosen_key])
rejected_content = self._extract_content(d[rejected_key])
chosen_text = tokenizer.apply_chat_template([
{"role": "user", "content": prompt_content},
{"role": "assistant", "content": chosen_content},
])
rejected_text = tokenizer.apply_chat_template([
{"role": "user", "content": prompt_content},
{"role": "assistant", "content": rejected_content},
])
self._chosen_data.append(chosen_text)
self._rejected_data.append(rejected_text)
def _extract_content(self, data):
"""Helper method to extract content from various data formats."""
if isinstance(data, str):
return data
elif isinstance(data, dict):
if "messages" in data:
last_message = data["messages"][-1]
return last_message.get("content", last_message.get("messages", ""))
return data.get("content", "")
elif isinstance(data, list):
last_message = data[-1]
if isinstance(last_message, dict):
if "content" in last_message:
return last_message["content"]
elif "messages" in last_message:
return last_message["messages"]
return last_message if isinstance(last_message, str) else ""
return ""
def __len__(self):
return len(self._chosen_data)
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return { return {
"chosen": self._chosen_data[idx], "chosen": self._chosen_data[idx],
"rejected": self._rejected_data[idx] "rejected": self._rejected_data[idx]
} }
def __len__(self):
return len(self._chosen_data)
class Dataset: class Dataset:
""" """
Light-weight wrapper to hold a dataset. Light-weight wrapper to hold a dataset.