mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 20:25:22 +08:00
update DPODataset and added in system field too
This commit is contained in:
parent
aefe4ba160
commit
54fcd8ed63
@ -9,7 +9,7 @@ class DPODataset:
|
|||||||
"""
|
"""
|
||||||
A dataset for DPO (Direct Preference Optimization) training that handles
|
A dataset for DPO (Direct Preference Optimization) training that handles
|
||||||
prompt-chosen-rejected triplets in the format:
|
prompt-chosen-rejected triplets in the format:
|
||||||
{"prompt": ..., "chosen": ..., "rejected": ...}
|
{"system": ..., "prompt": ..., "chosen": ..., "rejected": ...}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@ -19,26 +19,47 @@ class DPODataset:
|
|||||||
prompt_key: str = "prompt",
|
prompt_key: str = "prompt",
|
||||||
chosen_key: str = "chosen",
|
chosen_key: str = "chosen",
|
||||||
rejected_key: str = "rejected",
|
rejected_key: str = "rejected",
|
||||||
|
system_key: str = None
|
||||||
):
|
):
|
||||||
self._chosen_data = [
|
|
||||||
tokenizer.apply_chat_template(
|
self._chosen_data = []
|
||||||
[
|
self._rejected_data = []
|
||||||
{"role": "user", "content": d[prompt_key]},
|
self._scores = []
|
||||||
{"role": "assistant", "content": d[chosen_key]},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
for d in data
|
|
||||||
]
|
|
||||||
|
|
||||||
self._rejected_data = [
|
for d in data:
|
||||||
tokenizer.apply_chat_template(
|
if system_key and system_key in d:
|
||||||
[
|
chosen = tokenizer.apply_chat_template(
|
||||||
{"role": "user", "content": d[prompt_key]},
|
[
|
||||||
{"role": "assistant", "content": d[rejected_key]},
|
{"role": "system", "content": d[system_key]},
|
||||||
],
|
{"role": "user", "content": d[prompt_key]},
|
||||||
)
|
{"role": "assistant", "content": d[chosen_key]},
|
||||||
for d in data
|
]
|
||||||
]
|
)
|
||||||
|
|
||||||
|
rejected = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": d[system_key]},
|
||||||
|
{"role": "user", "content": d[prompt_key]},
|
||||||
|
{"role": "assistant", "content": d[rejected_key]},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
chosen = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{"role": "user", "content": d[prompt_key]},
|
||||||
|
{"role": "assistant", "content": d[chosen_key]},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
rejected = tokenizer.apply_chat_template(
|
||||||
|
[
|
||||||
|
{"role": "user", "content": d[prompt_key]},
|
||||||
|
{"role": "assistant", "content": d[rejected_key]},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
self._chosen_data.append(chosen)
|
||||||
|
self._rejected_data.append(rejected)
|
||||||
|
|
||||||
def __getitem__(self, idx: int):
|
def __getitem__(self, idx: int):
|
||||||
return {
|
return {
|
||||||
|
Loading…
Reference in New Issue
Block a user