mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-12-16 02:08:55 +08:00
removing dpo and fixing some stuff for orpo
@@ -4,70 +4,47 @@ from typing import Dict, List, Optional
 
 from transformers import PreTrainedTokenizer
 
 
-class DPODataset:
-    """
-    A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets with optional scores in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
-    """
-
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        tokenizer: PreTrainedTokenizer,
-        prompt_key: str = "prompt",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected",
-        score_chosen_key: str = "score_chosen",
-        score_rejected_key: str = "score_rejected",
-    ):
+class ORPODataset:
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        chosen_key: str = "chosen",
+        rejected_key: str = "rejected",
+        preference_score_key: str = "preference_score"
+    ):
         self._chosen_data = []
         self._rejected_data = []
         self._scores = []
 
         for d in data:
-            # Process the text data
-            chosen_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            rejected_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
+            chosen_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[chosen_key]},
+            ])
+            rejected_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[rejected_key]},
+            ])
 
             self._chosen_data.append(chosen_text)
             self._rejected_data.append(rejected_text)
 
-            # Handle scores if they exist
-            if score_chosen_key in d and score_rejected_key in d:
-                chosen_score = float(d[score_chosen_key])
-                rejected_score = float(d[score_rejected_key])
-
-                # Normalize scores to [0, 1] range
-                score_diff = chosen_score - rejected_score
-                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
-                normalized_score = (score_diff / max_diff + 1) / 2
-
-                self._scores.append(normalized_score)
+            if preference_score_key in d:
+                self._scores.append(float(d[preference_score_key]))
             else:
                 # Default to binary preference (1.0) if no scores provided
                 self._scores.append(1.0)
 
     def __getitem__(self, idx: int):
         return {
             "chosen": self._chosen_data[idx],
             "rejected": self._rejected_data[idx],
             "preference_score": self._scores[idx]
         }
 
     def __len__(self):
         return len(self._chosen_data)
 
 
 class Dataset:
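For context, a minimal usage sketch of the ORPODataset kept by this commit. It is not part of the commit itself: the model id, the sample records, and their scores are illustrative placeholders, and the import path assumes the class lives in mlx_lm.tuner.datasets as in upstream mlx-examples. Each record needs prompt, chosen, and rejected fields; preference_score is optional and falls back to 1.0.

# Illustrative sketch only (not part of this commit).
from transformers import AutoTokenizer

from mlx_lm.tuner.datasets import ORPODataset  # assumed import path

# Placeholder model id; any tokenizer that ships a chat template should work.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

data = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "The capital of France is Paris.",
        "rejected": "It might be Lyon.",
        "preference_score": 0.9,  # optional strength of the preference
    },
    {
        "prompt": "Name a prime number.",
        "chosen": "2 is the smallest prime number.",
        "rejected": "9 is a prime number.",
        # no preference_score -> stored as the default 1.0
    },
]

dataset = ORPODataset(data, tokenizer)

print(len(dataset))              # 2
item = dataset[0]
print(item["preference_score"])  # 0.9
# "chosen" and "rejected" hold whatever apply_chat_template returns
# (token ids by default, since tokenize defaults to True in transformers).
print(type(item["chosen"]))

Note that apply_chat_template is called without tokenize=False, so the stored "chosen" and "rejected" entries are token ids rather than strings, despite the chosen_text/rejected_text variable names.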
@@ -158,7 +135,7 @@ def create_dataset(
 
     # Add DPO dataset support
     if "chosen" in sample and "rejected" in sample:
-        return DPODataset(data, tokenizer)
+        return ORPODataset(data, tokenizer)
    elif "messages" in sample:
         return ChatDataset(data, tokenizer)
     elif prompt_feature in sample and completion_feature in sample:
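With this change, create_dataset routes any record that contains both chosen and rejected keys to ORPODataset; DPODataset no longer exists in this file. A hypothetical stand-alone check that mirrors the dispatch above (the helper name and the JSONL path are placeholders, not part of the commit):

import json

# Hypothetical helper mirroring the dispatch in create_dataset: peek at the
# first record of a JSONL file and report which dataset class would be used.
def detect_dataset_type(path: str) -> str:
    with open(path) as f:
        sample = json.loads(f.readline())
    if "chosen" in sample and "rejected" in sample:
        return "ORPODataset"   # preference pairs (formerly routed to DPODataset)
    elif "messages" in sample:
        return "ChatDataset"   # chat-style records
    else:
        return "other"         # e.g. prompt/completion or plain-text records

print(detect_dataset_type("train.jsonl"))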