fixes
@@ -8,8 +8,8 @@ from transformers import PreTrainedTokenizer
 class DPODataset:
     """
     A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ...}
+    prompt-chosen-rejected triplets with optional scores in the format:
+    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
     """
 
     def __init__(
@@ -19,31 +19,51 @@ class DPODataset:
         prompt_key: str = "prompt",
         chosen_key: str = "chosen",
         rejected_key: str = "rejected",
+        score_chosen_key: str = "score_chosen",
+        score_rejected_key: str = "score_rejected",
     ):
-        self._chosen_data = [
-            tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            for d in data
-        ]
-
-        self._rejected_data = [
-            tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
-            for d in data
-        ]
+        self._chosen_data = []
+        self._rejected_data = []
+        self._scores = []
+
+        for d in data:
+            # Process the text data
+            chosen_text = tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[chosen_key]},
+                ],
+            )
+            rejected_text = tokenizer.apply_chat_template(
+                [
+                    {"role": "user", "content": d[prompt_key]},
+                    {"role": "assistant", "content": d[rejected_key]},
+                ],
+            )
+
+            self._chosen_data.append(chosen_text)
+            self._rejected_data.append(rejected_text)
+
+            # Handle scores if they exist
+            if score_chosen_key in d and score_rejected_key in d:
+                chosen_score = float(d[score_chosen_key])
+                rejected_score = float(d[score_rejected_key])
+
+                # Normalize scores to [0, 1] range
+                score_diff = chosen_score - rejected_score
+                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
+                normalized_score = (score_diff / max_diff + 1) / 2
+
+                self._scores.append(normalized_score)
+            else:
+                # Default to binary preference (1.0) if no scores provided
+                self._scores.append(1.0)
 
     def __getitem__(self, idx: int):
         return {
             "chosen": self._chosen_data[idx],
-            "rejected": self._rejected_data[idx]
+            "rejected": self._rejected_data[idx],
+            "preference_score": self._scores[idx]
        }
 
     def __len__(self):
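For reference, the score handling introduced above collapses each (score_chosen, score_rejected) pair into a single preference weight in [0, 1], where 0.5 means no preference. Below is a minimal standalone sketch of that normalization; the helper name normalize_preference and the example scores are illustrative and not part of the commit:

def normalize_preference(chosen_score: float, rejected_score: float) -> float:
    # Positive when the chosen response scored higher than the rejected one.
    score_diff = chosen_score - rejected_score
    # Clamp the divisor at 1.0 so equal scores do not divide by zero.
    max_diff = max(abs(score_diff), 1.0)
    # Map from [-1, 1] into [0, 1].
    return (score_diff / max_diff + 1) / 2

print(normalize_preference(8.0, 3.0))  # 1.0  (chosen clearly preferred)
print(normalize_preference(4.0, 4.0))  # 0.5  (tie, no preference)
print(normalize_preference(2.0, 6.0))  # 0.0  (rejected scored higher)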