diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 5a06d90f..59e9b232 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -6,10 +6,10 @@
 from typing import Any, Dict, List, Optional, Union
 from transformers import PreTrainedTokenizer
 
 class ORPODataset:
     def __init__(
         self,
-        data: List[Dict[str, Union[str, Dict]]],
+        data: List[Dict[str, Union[str, Dict, List]]],
         tokenizer: PreTrainedTokenizer,
         prompt_key: str = "prompt",
         chosen_key: str = "chosen",
@@ -22,28 +22,51 @@ class ORPODataset:
         self._scores = []
 
         for d in data:
+            # Get prompt content, preferring 'prompt' over 'question'
+            prompt_content = d.get(prompt_key, d.get("question", ""))
+
             if system_key and system_key in d:
                 base_messages = [{"role": "system", "content": d[system_key]}]
-                chosen_messages = base_messages + [{"role": "user", "content": d[prompt_key]}]
+                chosen_messages = base_messages + [{"role": "user", "content": prompt_content}]
+                rejected_messages = base_messages + [{"role": "user", "content": prompt_content}]
+
+                # Handle chosen messages
                 if isinstance(d[chosen_key], str):
                     chosen_messages.append({"role": "assistant", "content": d[chosen_key]})
-                else:
-                    chosen_messages.extend(d[chosen_key]["messages"])
-                rejected_messages = base_messages + [{"role": "user", "content": d[prompt_key]}]
+                elif isinstance(d[chosen_key], dict):
+                    if "messages" in d[chosen_key]:
+                        chosen_messages.extend(d[chosen_key]["messages"])
+                    else:
+                        chosen_messages.append({"role": "assistant", "content": d[chosen_key].get("content", "")})
+                elif isinstance(d[chosen_key], list):
+                    chosen_messages.extend(d[chosen_key])
+
+                # Handle rejected messages
                 if isinstance(d[rejected_key], str):
                     rejected_messages.append({"role": "assistant", "content": d[rejected_key]})
-                else:
-                    rejected_messages.extend(d[rejected_key]["messages"])
+                elif isinstance(d[rejected_key], dict):
+                    if "messages" in d[rejected_key]:
+                        rejected_messages.extend(d[rejected_key]["messages"])
+                    else:
+                        rejected_messages.append({"role": "assistant", "content": d[rejected_key].get("content", "")})
+                elif isinstance(d[rejected_key], list):
+                    rejected_messages.extend(d[rejected_key])
+
                 chosen_text = tokenizer.apply_chat_template(chosen_messages)
                 rejected_text = tokenizer.apply_chat_template(rejected_messages)
+
             else:
+                # Handle non-system message cases
+                chosen_content = self._extract_content(d[chosen_key])
+                rejected_content = self._extract_content(d[rejected_key])
+
                 chosen_text = tokenizer.apply_chat_template([
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key] if isinstance(d[chosen_key], str) else d[chosen_key]["messages"][-1]["content"]},
+                    {"role": "user", "content": prompt_content},
+                    {"role": "assistant", "content": chosen_content},
                 ])
                 rejected_text = tokenizer.apply_chat_template([
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key] if isinstance(d[rejected_key], str) else d[rejected_key]["messages"][-1]["content"]},
+                    {"role": "user", "content": prompt_content},
+                    {"role": "assistant", "content": rejected_content},
                 ])
 
             self._chosen_data.append(chosen_text)
@@ -53,6 +76,25 @@ class ORPODataset:
                 self._scores.append(float(d[preference_score_key]))
             else:
                 self._scores.append(1.0)
+
+    def _extract_content(self, data):
+        """Helper method to extract content from various data formats."""
+        if isinstance(data, str):
+            return data
+        elif isinstance(data, dict):
+            if "messages" in data:
+                last_message = data["messages"][-1]
+                return last_message.get("content", last_message.get("messages", ""))
+            return data.get("content", "")
+        elif isinstance(data, list):
+            last_message = data[-1]
+            if isinstance(last_message, dict):
+                if "content" in last_message:
+                    return last_message["content"]
+                elif "messages" in last_message:
+                    return last_message["messages"]
+            return last_message if isinstance(last_message, str) else ""
+        return ""
 
     def __len__(self):
         return len(self._chosen_data)
@@ -213,7 +255,7 @@ def load_local_dataset(
         with open(path, "r") as fid:
             data = [json.loads(l) for l in fid]
-        return create_dataset(data, tokenizer, config)
+        return create_dataset(args, data, tokenizer, config)
 
     names = ("train", "valid", "test")
     train, valid, test = [load_subset(data_path / f"{n}.jsonl") for n in names]
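
For reference, a minimal sketch of the record shapes the extended ORPODataset is intended to accept after this change. The values are illustrative, not from any shipped dataset; the "system" and "preference_score" key names are assumed defaults for system_key and preference_score_key, which this diff does not show.

# Hypothetical example records: "chosen"/"rejected" may be a plain string,
# a dict with a "messages" list, or a bare list of chat messages.
example_records = [
    {
        "prompt": "What is the capital of France?",
        "chosen": "Paris.",
        "rejected": "Lyon.",
        "preference_score": 1.0,  # assumed default for preference_score_key
    },
    {
        "system": "You are a helpful assistant.",  # assumed default for system_key
        "question": "Name a prime number.",        # fallback key used when "prompt" is absent
        "chosen": {"messages": [{"role": "assistant", "content": "Seven."}]},
        "rejected": [{"role": "assistant", "content": "Nine."}],
    },
]

# dataset = ORPODataset(example_records, tokenizer)  # tokenizer: a PreTrainedTokenizer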