mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 04:31:13 +08:00
Update DPODataset and add a system field too
This commit is contained in:
parent
aefe4ba160
commit
54fcd8ed63
@ -9,7 +9,7 @@ class DPODataset:
|
||||
"""
|
||||
A dataset for DPO (Direct Preference Optimization) training that handles
|
||||
prompt-chosen-rejected triplets in the format:
|
||||
{"prompt": ..., "chosen": ..., "rejected": ...}
|
||||
{"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@ -19,26 +19,47 @@ class DPODataset:
|
||||
prompt_key: str = "prompt",
|
||||
chosen_key: str = "chosen",
|
||||
rejected_key: str = "rejected",
|
||||
system_key: str = None
|
||||
):
|
||||
self._chosen_data = [
|
||||
tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[chosen_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
|
||||
self._chosen_data = []
|
||||
self._rejected_data = []
|
||||
self._scores = []
|
||||
|
||||
self._rejected_data = [
|
||||
tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[rejected_key]},
|
||||
],
|
||||
)
|
||||
for d in data
|
||||
]
|
||||
for d in data:
|
||||
if system_key and system_key in d:
|
||||
chosen = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "system", "content": d[system_key]},
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[chosen_key]},
|
||||
]
|
||||
)
|
||||
|
||||
rejected = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "system", "content": d[system_key]},
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[rejected_key]},
|
||||
],
|
||||
)
|
||||
else:
|
||||
chosen = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[chosen_key]},
|
||||
]
|
||||
)
|
||||
|
||||
rejected = tokenizer.apply_chat_template(
|
||||
[
|
||||
{"role": "user", "content": d[prompt_key]},
|
||||
{"role": "assistant", "content": d[rejected_key]},
|
||||
],
|
||||
)
|
||||
|
||||
self._chosen_data.append(chosen)
|
||||
self._rejected_data.append(rejected)
|
||||
|
||||
def __getitem__(self, idx: int):
|
||||
return {
|
||||
|
Loading…
Reference in New Issue
Block a user