dataset wrapper done

Goekdeniz-Guelmez 2025-02-03 09:13:17 +01:00
parent d034ca369e
commit a3ed632422
2 changed files with 50 additions and 7 deletions

View File

@@ -5,6 +5,43 @@ from typing import Dict, List, Optional
 from transformers import PreTrainedTokenizer
 
+class GRPODataset:
+    """
+    Dataset wrapper for GRPO training data.
+    Each example should have a 'prompt' and 'answer' field.
+    """
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        answer_key: str = "answer"
+    ):
+        self._data = []
+        for item in data:
+            # Tokenize prompt and answer
+            prompt_tokens = tokenizer.encode(item[prompt_key])
+            answer_tokens = tokenizer.encode(item[answer_key])
+            # Add EOS tokens if needed
+            if prompt_tokens[-1] != tokenizer.eos_token_id:
+                prompt_tokens.append(tokenizer.eos_token_id)
+            if answer_tokens[-1] != tokenizer.eos_token_id:
+                answer_tokens.append(tokenizer.eos_token_id)
+            self._data.append({
+                'prompt': prompt_tokens,
+                'answer': answer_tokens
+            })
+
+    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
+        return self._data[idx]
+
+    def __len__(self) -> int:
+        return len(self._data)
+
 class Dataset:
     """
     Light-weight wrapper to hold a dataset.
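For orientation, a minimal usage sketch of the new wrapper (the model name, file path, and record list below are illustrative placeholders, not part of this commit):

import json
from transformers import AutoTokenizer

# Placeholder model; any PreTrainedTokenizer works here.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

# One JSON object per line, each with the default 'prompt' and 'answer' keys.
with open("train.jsonl") as f:
    records = [json.loads(line) for line in f]

train_dataset = GRPODataset(records, tokenizer)
print(len(train_dataset))           # number of examples
print(train_dataset[0]["prompt"])   # EOS-terminated prompt token ids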

View File

@@ -130,13 +130,7 @@ def grpo_loss(
     model,
     tokenizer,
     prompts,
-    reward_funcs=[
-        r1_accuracy_reward_func,
-        r1_int_reward_func,
-        r1_strict_format_reward_func,
-        r1_soft_format_reward_func,
-        r1_count_xml
-    ],
+    reward_funcs=None,
     beta=0.1,
     group_size=4,
     epsilon=1e-4,
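Replacing the inline list with reward_funcs=None follows the standard None-sentinel idiom: a mutable default is evaluated once at function definition time and shared across every call. How grpo_loss resolves the None is not shown in this hunk, but a self-contained illustration of the pitfall being avoided (the bad/good helpers below are hypothetical, not from this file):

def bad(items=[]):
    # The single default list is created once and shared by every call.
    items.append(1)
    return items

def good(items=None):
    # A fresh list per call; the same pattern lets grpo_loss fall back
    # to the r1_* rewards when the caller passes nothing.
    if items is None:
        items = []
    items.append(1)
    return items

print(bad())   # [1]
print(bad())   # [1, 1]  <- state leaked from the first call
print(good())  # [1]
print(good())  # [1]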
@@ -386,10 +380,18 @@ def evaluate_grpo(
 def train_grpo(
     model,
     ref_model,
     tokenizer,
     optimizer,
     train_dataset,
     val_dataset,
+    reward_funcs=[
+        r1_accuracy_reward_func,
+        r1_int_reward_func,
+        r1_strict_format_reward_func,
+        r1_soft_format_reward_func,
+        r1_count_xml
+    ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
     loss: callable = grpo_loss,
     iterate_batches: callable = iterate_batches,
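Because train_grpo takes the rewards as a plain list of callables, custom rewards can be appended to the r1_* set. Their exact signature is not visible in this diff, so the sketch below assumes the common GRPO convention of scoring a batch of completions and returning one float per completion:

from typing import List

def length_penalty_reward(prompts: List[str], completions: List[str],
                          answer: List[str], **kwargs) -> List[float]:
    # Hypothetical extra reward: discourage overly long completions.
    # The (prompts, completions, answer, **kwargs) signature is an
    # assumption; match whatever the r1_* functions in this file use.
    return [max(0.0, 1.0 - len(c) / 2048) for c in completions]

# e.g. reward_funcs=[r1_accuracy_reward_func, ..., length_penalty_reward]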
@@ -452,6 +454,10 @@ def train_grpo(
                 model=model,
                 dataset=val_dataset,
                 loss=loss,
+                ref_model=model,
+                reward_funcs=reward_funcs,
+                tokenizer=tokenizer,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
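Note that this hunk passes the policy itself as ref_model to the validation call, which, if the loss uses the standard GRPO KL penalty against the reference, would make that term vanish during evaluation. Putting both files together, an end-to-end wiring sketch (model, ref_model, optimizer, and the record lists are assumed to exist; the keyword names come from the train_grpo signature above):

train_grpo(
    model=model,                # policy being trained
    ref_model=ref_model,        # frozen reference for the KL penalty
    tokenizer=tokenizer,
    optimizer=optimizer,
    train_dataset=GRPODataset(train_records, tokenizer),
    val_dataset=GRPODataset(val_records, tokenizer),
)                               # reward_funcs defaults to the r1_* list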