starting fist training test run

This commit is contained in:
Goekdeniz-Guelmez
2025-02-03 10:08:28 +01:00
parent 41ff5364d7
commit 23d75cd7ad
3 changed files with 109 additions and 77 deletions

View File

@@ -1,6 +1,6 @@
import json
from pathlib import Path
from typing import Dict, List, Optional
from typing import Dict, List, Optional, Tuple
from transformers import PreTrainedTokenizer
@@ -9,36 +9,30 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt, answer) tuple format required by GRPO trainer.
"""
def __init__(
self,
data: List[Dict[str, str]],
tokenizer: PreTrainedTokenizer,
prompt_key: str = "prompt",
prompt_key: str = "prompt",
answer_key: str = "answer"
):
self._data = []
for item in data:
# Tokenize prompt and answer
prompt_tokens = tokenizer.encode(item[prompt_key])
answer_tokens = tokenizer.encode(item[answer_key])
# Get prompt and answer text
prompt = str(item[prompt_key])
answer = str(item[answer_key])
# Add EOS tokens if needed
if prompt_tokens[-1] != tokenizer.eos_token_id:
prompt_tokens.append(tokenizer.eos_token_id)
if answer_tokens[-1] != tokenizer.eos_token_id:
answer_tokens.append(tokenizer.eos_token_id)
self._data.append({
'prompt': prompt_tokens,
'answer': answer_tokens
})
# Store as (prompt, answer) tuple
self._data.append((prompt, answer))
def __getitem__(self, idx: int) -> Dict[str, List[int]]:
def __getitem__(self, idx: int) -> Tuple[str, str]:
"""Returns a (prompt, answer) tuple for the given index."""
return self._data[idx]
def __len__(self) -> int:
"""Returns the number of examples in the dataset."""
return len(self._data)
@@ -127,8 +121,11 @@ def create_dataset(
prompt_feature = prompt_feature or "prompt"
completion_feature = completion_feature or "completion"
sample = data[0]
if "messages" in sample:
return ChatDataset(data, tokenizer)
elif "prompt" in sample and "answer" in sample:
return GRPODataset(data, tokenizer, "prompt", "answer") # Use GRPO Dataset
elif prompt_feature in sample and completion_feature in sample:
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
elif "text" in sample: