first working prototype, will try training out at home

This commit is contained in:
Goekdeniz-Guelmez
2025-02-03 12:05:29 +01:00
parent 23d75cd7ad
commit 1d9e4802f0
2 changed files with 254 additions and 122 deletions

View File

@@ -9,7 +9,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data in (prompt, answer) tuple format required by GRPO trainer.
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
"""
def __init__(
self,
@@ -20,15 +20,14 @@ class GRPODataset:
):
self._data = []
for item in data:
# Get prompt and answer text
prompt = str(item[prompt_key])
answer = str(item[answer_key])
# Store as (prompt, answer) tuple
self._data.append((prompt, answer))
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
def __getitem__(self, idx: int) -> Tuple[str, str]:
"""Returns a (prompt, answer) tuple for the given index."""
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]
def __len__(self) -> int: