From a3ed632422d3b2bf3a0efbcd1624a2ab38360e6a Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Mon, 3 Feb 2025 09:13:17 +0100
Subject: [PATCH] dataset wrapper done

---
 llms/mlx_lm/tuner/datasets.py     | 37 +++++++++++++++++++++++++++++++
 llms/mlx_lm/tuner/grpo_trainer.py | 20 +++++++++++------
 2 files changed, 50 insertions(+), 7 deletions(-)

diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index 377e7cae..8f185473 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -5,6 +5,43 @@ from typing import Dict, List, Optional
 from transformers import PreTrainedTokenizer
 
 
+class GRPODataset:
+    """
+    Dataset wrapper for GRPO training data.
+    Each example should have a 'prompt' and 'answer' field.
+    """
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        answer_key: str = "answer"
+    ):
+        self._data = []
+
+        for item in data:
+            # Tokenize prompt and answer
+            prompt_tokens = tokenizer.encode(item[prompt_key])
+            answer_tokens = tokenizer.encode(item[answer_key])
+
+            # Add EOS tokens if needed
+            if prompt_tokens[-1] != tokenizer.eos_token_id:
+                prompt_tokens.append(tokenizer.eos_token_id)
+            if answer_tokens[-1] != tokenizer.eos_token_id:
+                answer_tokens.append(tokenizer.eos_token_id)
+
+            self._data.append({
+                'prompt': prompt_tokens,
+                'answer': answer_tokens
+            })
+
+    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
+        return self._data[idx]
+
+    def __len__(self) -> int:
+        return len(self._data)
+
+
 class Dataset:
     """
     Light-weight wrapper to hold a dataset.
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index ac735264..31edc0ec 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -130,13 +130,7 @@ def grpo_loss(
     model,
     tokenizer,
     prompts,
-    reward_funcs=[
-        r1_accuracy_reward_func,
-        r1_int_reward_func,
-        r1_strict_format_reward_func,
-        r1_soft_format_reward_func,
-        r1_count_xml
-    ],
+    reward_funcs=None,
     beta=0.1,
     group_size=4,
     epsilon=1e-4,
@@ -386,10 +380,18 @@ def evaluate_grpo(
 
 def train_grpo(
     model,
+    ref_model,
     tokenizer,
     optimizer,
     train_dataset,
     val_dataset,
+    reward_funcs = [
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
     loss: callable = grpo_loss,
     iterate_batches: callable = iterate_batches,
@@ -452,6 +454,10 @@ def train_grpo(
         model=model,
         dataset=val_dataset,
         loss=loss,
+
+        ref_model=model,
+        reward_funcs=reward_funcs,
+        tokenizer=tokenizer,
         batch_size=args.batch_size,
         num_batches=args.val_batches,
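
Usage sketch (not part of the patch): a minimal, hypothetical example of how the
GRPODataset added above might be constructed and read. It assumes the class is
importable as mlx_lm.tuner.datasets.GRPODataset at this point in the branch, that
a Hugging Face tokenizer exposing eos_token_id is used, and that the file name
"train.jsonl" and the model id are illustrative placeholders, not values taken
from the patch.

    import json

    from transformers import AutoTokenizer

    from mlx_lm.tuner.datasets import GRPODataset

    # Hypothetical tokenizer; any PreTrainedTokenizer with an eos_token_id works.
    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

    # Hypothetical JSONL file where each line has "prompt" and "answer" fields.
    with open("train.jsonl") as f:
        data = [json.loads(line) for line in f]

    dataset = GRPODataset(data, tokenizer, prompt_key="prompt", answer_key="answer")

    print(len(dataset))          # number of examples
    print(dataset[0]["prompt"])  # EOS-terminated token ids of the first prompt
    print(dataset[0]["answer"])  # EOS-terminated token ids of the first answer

Each example is tokenized once at construction time, so batching code can index
the dataset directly without re-running the tokenizer.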