better dataset handling

Goekdeniz-Guelmez 2025-02-21 21:12:45 +01:00
parent de147187c1
commit 5704136791

@@ -6,10 +6,19 @@ from typing import Any, Dict, List, Optional, Union
 from transformers import PreTrainedTokenizer
+from typing import List, Dict, Union
+from transformers import PreTrainedTokenizer
+from typing import List, Dict, Union
+from transformers import PreTrainedTokenizer
+from typing import List, Dict, Union
+from transformers import PreTrainedTokenizer
 class ORPODataset:
     def __init__(
         self,
-        data: List[Dict[str, Union[str, Dict]]],
+        data: List[Dict[str, Union[str, Dict, List]]],
         tokenizer: PreTrainedTokenizer,
         prompt_key: str = "prompt",
         chosen_key: str = "chosen",
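The widened annotation above means each dataset record's chosen/rejected field may now be a plain string, a dict, or a list of chat messages. A minimal sketch of the three record shapes the new Union[str, Dict, List] type admits (the example records are illustrative, not taken from the repository):

# Three record shapes the new annotation admits (illustrative examples).
record_as_string = {
    "prompt": "What is 2 + 2?",
    "chosen": "4",
    "rejected": "5",
}
record_as_dict = {
    "prompt": "What is 2 + 2?",
    "chosen": {"messages": [{"role": "assistant", "content": "4"}]},
    "rejected": {"messages": [{"role": "assistant", "content": "5"}]},
}
record_as_list = {
    "prompt": "What is 2 + 2?",
    "chosen": [{"role": "assistant", "content": "4"}],
    "rejected": [{"role": "assistant", "content": "5"}],
}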
@@ -22,28 +31,51 @@ class ORPODataset:
         self._scores = []
         for d in data:
+            # Get prompt content, preferring 'prompt' over 'question'
+            prompt_content = d.get(prompt_key, d.get("question", ""))
             if system_key and system_key in d:
                 base_messages = [{"role": "system", "content": d[system_key]}]
-                chosen_messages = base_messages + [{"role": "user", "content": d[prompt_key]}]
+                chosen_messages = base_messages + [{"role": "user", "content": prompt_content}]
+                rejected_messages = base_messages + [{"role": "user", "content": prompt_content}]
+                # Handle chosen messages
                 if isinstance(d[chosen_key], str):
                     chosen_messages.append({"role": "assistant", "content": d[chosen_key]})
-                else:
-                    chosen_messages.extend(d[chosen_key]["messages"])
-                rejected_messages = base_messages + [{"role": "user", "content": d[prompt_key]}]
+                elif isinstance(d[chosen_key], dict):
+                    if "messages" in d[chosen_key]:
+                        chosen_messages.extend(d[chosen_key]["messages"])
+                    else:
+                        chosen_messages.append({"role": "assistant", "content": d[chosen_key].get("content", "")})
+                elif isinstance(d[chosen_key], list):
+                    chosen_messages.extend(d[chosen_key])
+                # Handle rejected messages
                 if isinstance(d[rejected_key], str):
                     rejected_messages.append({"role": "assistant", "content": d[rejected_key]})
-                else:
-                    rejected_messages.extend(d[rejected_key]["messages"])
+                elif isinstance(d[rejected_key], dict):
+                    if "messages" in d[rejected_key]:
+                        rejected_messages.extend(d[rejected_key]["messages"])
+                    else:
+                        rejected_messages.append({"role": "assistant", "content": d[rejected_key].get("content", "")})
+                elif isinstance(d[rejected_key], list):
+                    rejected_messages.extend(d[rejected_key])
                 chosen_text = tokenizer.apply_chat_template(chosen_messages)
                 rejected_text = tokenizer.apply_chat_template(rejected_messages)
             else:
+                # Handle non-system message cases
+                chosen_content = self._extract_content(d[chosen_key])
+                rejected_content = self._extract_content(d[rejected_key])
                 chosen_text = tokenizer.apply_chat_template([
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key] if isinstance(d[chosen_key], str) else d[chosen_key]["messages"][-1]["content"]},
+                    {"role": "user", "content": prompt_content},
+                    {"role": "assistant", "content": chosen_content},
                 ])
                 rejected_text = tokenizer.apply_chat_template([
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key] if isinstance(d[rejected_key], str) else d[rejected_key]["messages"][-1]["content"]},
+                    {"role": "user", "content": prompt_content},
+                    {"role": "assistant", "content": rejected_content},
                 ])
             self._chosen_data.append(chosen_text)
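The branching above normalizes all three value shapes into one chat transcript before tokenizer.apply_chat_template renders it. A condensed, standalone sketch of that normalization (to_assistant_messages is a hypothetical helper, not a function in this diff):

# Hypothetical condensation of the chosen/rejected branching above.
def to_assistant_messages(value):
    """Normalize a str / dict / list chosen or rejected value to chat messages."""
    if isinstance(value, str):
        return [{"role": "assistant", "content": value}]
    if isinstance(value, dict):
        if "messages" in value:
            return value["messages"]
        return [{"role": "assistant", "content": value.get("content", "")}]
    if isinstance(value, list):
        return value
    return []

# All three input shapes collapse to the same transcript:
expected = [{"role": "assistant", "content": "4"}]
assert to_assistant_messages("4") == expected
assert to_assistant_messages({"messages": expected}) == expected
assert to_assistant_messages(expected) == expected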
@@ -53,6 +85,25 @@ class ORPODataset:
                 self._scores.append(float(d[preference_score_key]))
             else:
                 self._scores.append(1.0)
+    def _extract_content(self, data):
+        """Helper method to extract content from various data formats."""
+        if isinstance(data, str):
+            return data
+        elif isinstance(data, dict):
+            if "messages" in data:
+                last_message = data["messages"][-1]
+                return last_message.get("content", last_message.get("messages", ""))
+            return data.get("content", "")
+        elif isinstance(data, list):
+            last_message = data[-1]
+            if isinstance(last_message, dict):
+                if "content" in last_message:
+                    return last_message["content"]
+                elif "messages" in last_message:
+                    return last_message["messages"]
+            return last_message if isinstance(last_message, str) else ""
+        return ""
     def __len__(self):
         return len(self._chosen_data)
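For reference, this is what _extract_content returns for each supported input shape; since the method reads no instance state, it can be exercised without building a dataset (the sample values are illustrative):

# _extract_content touches no instance state, so __init__ can be skipped for a quick check.
ds = ORPODataset.__new__(ORPODataset)
assert ds._extract_content("plain answer") == "plain answer"
assert ds._extract_content({"content": "dict answer"}) == "dict answer"
assert ds._extract_content({"messages": [{"role": "assistant", "content": "last turn"}]}) == "last turn"
assert ds._extract_content([{"role": "assistant", "content": "list answer"}]) == "list answer"
assert ds._extract_content(42) == ""  # unrecognized types fall through to ""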
@@ -213,7 +264,7 @@ def load_local_dataset(
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, config)
+        return create_dataset(args, data, tokenizer, config)

     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
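Per the names tuple above, load_local_dataset expects train.jsonl, valid.jsonl, and test.jsonl under the data path, with one JSON record per line. A minimal sketch of producing such a file (record contents illustrative):

import json

records = [
    {"prompt": "What is 2 + 2?", "chosen": "4", "rejected": "5"},
]
with open("train.jsonl", "w") as fid:  # repeat for valid.jsonl / test.jsonl
    for record in records:
        fid.write(json.dumps(record) + "\n")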