Update DPODataset and add an optional system field

This commit is contained in:
Goekdeniz-Guelmez 2025-01-24 18:11:56 +01:00
parent aefe4ba160
commit 54fcd8ed63

View File

@ -9,7 +9,7 @@ class DPODataset:
""" """
A dataset for DPO (Direct Preference Optimization) training that handles A dataset for DPO (Direct Preference Optimization) training that handles
prompt-chosen-rejected triplets in the format: prompt-chosen-rejected triplets in the format:
{"prompt": ..., "chosen": ..., "rejected": ...} {"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
""" """
def __init__( def __init__(
@ -19,26 +19,47 @@ class DPODataset:
prompt_key: str = "prompt", prompt_key: str = "prompt",
chosen_key: str = "chosen", chosen_key: str = "chosen",
rejected_key: str = "rejected", rejected_key: str = "rejected",
system_key: str = None
): ):
self._chosen_data = [
tokenizer.apply_chat_template( self._chosen_data = []
self._rejected_data = []
self._scores = []
for d in data:
if system_key and system_key in d:
chosen = tokenizer.apply_chat_template(
[
{"role": "system", "content": d[system_key]},
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]},
]
)
rejected = tokenizer.apply_chat_template(
[
{"role": "system", "content": d[system_key]},
{"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]},
],
)
else:
chosen = tokenizer.apply_chat_template(
[ [
{"role": "user", "content": d[prompt_key]}, {"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[chosen_key]}, {"role": "assistant", "content": d[chosen_key]},
],
)
for d in data
] ]
)
self._rejected_data = [ rejected = tokenizer.apply_chat_template(
tokenizer.apply_chat_template(
[ [
{"role": "user", "content": d[prompt_key]}, {"role": "user", "content": d[prompt_key]},
{"role": "assistant", "content": d[rejected_key]}, {"role": "assistant", "content": d[rejected_key]},
], ],
) )
for d in data
] self._chosen_data.append(chosen)
self._rejected_data.append(rejected)
def __getitem__(self, idx: int): def __getitem__(self, idx: int):
return { return {