mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
first working prototype, will try training out at home
This commit is contained in:
@@ -9,7 +9,7 @@ class GRPODataset:
|
||||
"""
|
||||
Dataset wrapper for GRPO training data.
|
||||
Each example should have a 'prompt' and 'answer' field.
|
||||
Returns data in (prompt, answer) tuple format required by GRPO trainer.
|
||||
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
@@ -20,15 +20,14 @@ class GRPODataset:
|
||||
):
|
||||
self._data = []
|
||||
for item in data:
|
||||
# Get prompt and answer text
|
||||
prompt = str(item[prompt_key])
|
||||
answer = str(item[answer_key])
|
||||
|
||||
# Store as (prompt, answer) tuple
|
||||
self._data.append((prompt, answer))
|
||||
prompt_str = str(item[prompt_key])
|
||||
answer_str = str(item[answer_key])
|
||||
prompt_tokens = tokenizer.encode(prompt_str)
|
||||
answer_tokens = tokenizer.encode(answer_str)
|
||||
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
|
||||
|
||||
def __getitem__(self, idx: int) -> Tuple[str, str]:
|
||||
"""Returns a (prompt, answer) tuple for the given index."""
|
||||
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
|
||||
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self) -> int:
|
||||
|
||||
Reference in New Issue
Block a user