From 54fcd8ed63d9fc80835f20cf478227b544c1de85 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Fri, 24 Jan 2025 18:11:56 +0100 Subject: [PATCH] update DPODataset and added in system field too --- llms/mlx_lm/tuner/datasets.py | 59 ++++++++++++++++++++++++----------- 1 file changed, 40 insertions(+), 19 deletions(-) diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py index 78ce7a1a..0f23c2a4 100644 --- a/llms/mlx_lm/tuner/datasets.py +++ b/llms/mlx_lm/tuner/datasets.py @@ -9,7 +9,7 @@ class DPODataset: """ A dataset for DPO (Direct Preference Optimization) training that handles prompt-chosen-rejected triplets in the format: - {"prompt": ..., "chosen": ..., "rejected": ...} + {"system": ..., "prompt": ..., "chosen": ..., "rejected": ...} """ def __init__( @@ -19,26 +19,47 @@ class DPODataset: prompt_key: str = "prompt", chosen_key: str = "chosen", rejected_key: str = "rejected", + system_key: str = None ): - self._chosen_data = [ - tokenizer.apply_chat_template( - [ - {"role": "user", "content": d[prompt_key]}, - {"role": "assistant", "content": d[chosen_key]}, - ], - ) - for d in data - ] + + self._chosen_data = [] + self._rejected_data = [] + self._scores = [] - self._rejected_data = [ - tokenizer.apply_chat_template( - [ - {"role": "user", "content": d[prompt_key]}, - {"role": "assistant", "content": d[rejected_key]}, - ], - ) - for d in data - ] + for d in data: + if system_key and system_key in d: + chosen = tokenizer.apply_chat_template( + [ + {"role": "system", "content": d[system_key]}, + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[chosen_key]}, + ] + ) + + rejected = tokenizer.apply_chat_template( + [ + {"role": "system", "content": d[system_key]}, + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[rejected_key]}, + ], + ) + else: + chosen = tokenizer.apply_chat_template( + [ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[chosen_key]}, + ] + ) + + rejected = tokenizer.apply_chat_template( + [ + {"role": "user", "content": d[prompt_key]}, + {"role": "assistant", "content": d[rejected_key]}, + ], + ) + + self._chosen_data.append(chosen) + self._rejected_data.append(rejected) def __getitem__(self, idx: int): return {