nits and quality of life improvements

This commit is contained in:
Goekdeniz-Guelmez
2025-01-24 22:40:27 +01:00
parent 531c3345c6
commit 86b315fdf9
4 changed files with 43 additions and 120 deletions

View File

@@ -12,54 +12,25 @@ class DPODataset:
{"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
chosen_key: str = "chosen",
rejected_key: str = "rejected",
system_key: str = "system",
):
def __init__(self, data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt", chosen_key: str = "chosen",
rejected_key: str = "rejected", system_key: str = "system"):
self._chosen_data = []
self._rejected_data = []
self._scores = []
for d in data:
if system_key and system_key in d:
chosen = tokenizer.apply_chat_template(
[
{"role": "system", "content": d[system_key]},
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
]
)
messages = (
[{"role": "system", "content": d[system_key]}] if system_key and system_key in d else []
)
messages.append({"role": "user", "content": d[prompt_key]})
rejected = tokenizer.apply_chat_template(
[
{"role": "system", "content": d[system_key]},
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
else:
chosen = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
]
)
rejected = tokenizer.apply_chat_template(
[
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
self._chosen_data.append(chosen)
self._rejected_data.append(rejected)
# Apply template once for each response type
base_messages = messages.copy()
chosen_messages = base_messages + [{"role": "assistant", "content": d[chosen_key]}]
rejected_messages = base_messages + [{"role": "assistant", "content": d[rejected_key]}]
self._chosen_data.append(tokenizer.apply_chat_template(chosen_messages))
self._rejected_data.append(tokenizer.apply_chat_template(rejected_messages))
def __getitem__(self, idx: int):
return {