mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
starting fist training test run
This commit is contained in:
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Optional, Tuple
|
||||
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
@@ -9,36 +9,30 @@ class GRPODataset:
|
||||
"""
|
||||
Dataset wrapper for GRPO training data.
|
||||
Each example should have a 'prompt' and 'answer' field.
|
||||
Returns data in (prompt, answer) tuple format required by GRPO trainer.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
data: List[Dict[str, str]],
|
||||
tokenizer: PreTrainedTokenizer,
|
||||
prompt_key: str = "prompt",
|
||||
prompt_key: str = "prompt",
|
||||
answer_key: str = "answer"
|
||||
):
|
||||
self._data = []
|
||||
|
||||
for item in data:
|
||||
# Tokenize prompt and answer
|
||||
prompt_tokens = tokenizer.encode(item[prompt_key])
|
||||
answer_tokens = tokenizer.encode(item[answer_key])
|
||||
# Get prompt and answer text
|
||||
prompt = str(item[prompt_key])
|
||||
answer = str(item[answer_key])
|
||||
|
||||
# Add EOS tokens if needed
|
||||
if prompt_tokens[-1] != tokenizer.eos_token_id:
|
||||
prompt_tokens.append(tokenizer.eos_token_id)
|
||||
if answer_tokens[-1] != tokenizer.eos_token_id:
|
||||
answer_tokens.append(tokenizer.eos_token_id)
|
||||
|
||||
self._data.append({
|
||||
'prompt': prompt_tokens,
|
||||
'answer': answer_tokens
|
||||
})
|
||||
# Store as (prompt, answer) tuple
|
||||
self._data.append((prompt, answer))
|
||||
|
||||
def __getitem__(self, idx: int) -> Dict[str, List[int]]:
|
||||
def __getitem__(self, idx: int) -> Tuple[str, str]:
|
||||
"""Returns a (prompt, answer) tuple for the given index."""
|
||||
return self._data[idx]
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the number of examples in the dataset."""
|
||||
return len(self._data)
|
||||
|
||||
|
||||
@@ -127,8 +121,11 @@ def create_dataset(
|
||||
prompt_feature = prompt_feature or "prompt"
|
||||
completion_feature = completion_feature or "completion"
|
||||
sample = data[0]
|
||||
|
||||
if "messages" in sample:
|
||||
return ChatDataset(data, tokenizer)
|
||||
elif "prompt" in sample and "answer" in sample:
|
||||
return GRPODataset(data, tokenizer, "prompt", "answer") # Use GRPO Dataset
|
||||
elif prompt_feature in sample and completion_feature in sample:
|
||||
return CompletionsDataset(data, tokenizer, prompt_feature, completion_feature)
|
||||
elif "text" in sample:
|
||||
|
||||
Reference in New Issue
Block a user