mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 09:51:19 +08:00
dataset wrapper done
This commit is contained in:
parent d034ca369e
commit a3ed632422
@@ -5,6 +5,43 @@ from typing import Dict, List, Optional

from transformers import PreTrainedTokenizer


class GRPODataset:
    """
    Dataset wrapper for GRPO training data.
    Each example should have a 'prompt' and 'answer' field.
    """
    def __init__(
        self,
        data: List[Dict[str, str]],
        tokenizer: PreTrainedTokenizer,
        prompt_key: str = "prompt",
        answer_key: str = "answer"
    ):
        self._data = []

        for item in data:
            # Tokenize prompt and answer
            prompt_tokens = tokenizer.encode(item[prompt_key])
            answer_tokens = tokenizer.encode(item[answer_key])

            # Add EOS tokens if needed
            if prompt_tokens[-1] != tokenizer.eos_token_id:
                prompt_tokens.append(tokenizer.eos_token_id)
            if answer_tokens[-1] != tokenizer.eos_token_id:
                answer_tokens.append(tokenizer.eos_token_id)

            self._data.append({
                'prompt': prompt_tokens,
                'answer': answer_tokens
            })

    def __getitem__(self, idx: int) -> Dict[str, List[int]]:
        return self._data[idx]

    def __len__(self) -> int:
        return len(self._data)


class Dataset:
    """
    Light-weight wrapper to hold a dataset.
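For reference, a minimal usage sketch of the new wrapper; the gpt2 tokenizer and the toy records below are illustrative assumptions, not part of this commit:

# Hypothetical usage sketch; "gpt2" is only a stand-in tokenizer.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
data = [
    {"prompt": "What is 2 + 2?", "answer": "4"},
    {"prompt": "Name a prime number.", "answer": "7"},
]

ds = GRPODataset(data, tok)
print(len(ds))          # 2
print(ds[0]["prompt"])  # prompt token ids, ending in tok.eos_token_id
print(ds[0]["answer"])  # answer token ids, ending in tok.eos_token_id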
@@ -130,13 +130,7 @@ def grpo_loss(
     model,
     tokenizer,
     prompts,
-    reward_funcs=[
-        r1_accuracy_reward_func,
-        r1_int_reward_func,
-        r1_strict_format_reward_func,
-        r1_soft_format_reward_func,
-        r1_count_xml
-    ],
+    reward_funcs=None,
     beta=0.1,
     group_size=4,
     epsilon=1e-4,
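Dropping the mutable list default in favor of reward_funcs=None avoids sharing one list object across calls; the usual pattern is then to resolve None to the full reward-function list inside the body. A minimal sketch of that pattern with stand-in reward functions (the names below are illustrative, not the r1_* functions from this module):

# Sketch of the None-default pattern; reward_a / reward_b are toy stand-ins.
def reward_a(completions): return [1.0 for _ in completions]
def reward_b(completions): return [0.5 for _ in completions]

DEFAULT_REWARD_FUNCS = [reward_a, reward_b]

def resolve_rewards(completions, reward_funcs=None):
    # Resolve the None default to the module-level list at call time.
    if reward_funcs is None:
        reward_funcs = DEFAULT_REWARD_FUNCS
    # Sum per-completion rewards across all reward functions.
    totals = [0.0] * len(completions)
    for func in reward_funcs:
        for i, r in enumerate(func(completions)):
            totals[i] += r
    return totals

print(resolve_rewards(["a", "b"]))  # [1.5, 1.5]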
@@ -386,10 +380,18 @@ def evaluate_grpo(


def train_grpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs = [
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml
    ],
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss: callable = grpo_loss,
    iterate_batches: callable = iterate_batches,
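A hedged sketch of how this signature might be driven end to end; the model path, toy records, optimizer settings, and hyperparameters below are assumptions for illustration, not values from this commit:

# Hypothetical call sketch. Assumes GRPODataset, GRPOTrainingArgs, and
# train_grpo are importable from the GRPO trainer module touched by this
# commit (exact import path not shown in the diff).
import mlx.optimizers as optim
from mlx_lm import load

model, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")  # any MLX-compatible model

records = [{"prompt": "What is 2 + 2?", "answer": "4"}]
train_set = GRPODataset(records, tokenizer)
val_set = GRPODataset(records, tokenizer)

train_grpo(
    model=model,
    ref_model=model,  # a frozen copy of the policy would normally go here
    tokenizer=tokenizer,
    optimizer=optim.Adam(learning_rate=1e-5),
    train_dataset=train_set,
    val_dataset=val_set,
    args=GRPOTrainingArgs(),  # reward_funcs is left at its r1_* default
)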
@@ -452,6 +454,10 @@ def train_grpo(
        model=model,
        dataset=val_dataset,
        loss=loss,
        ref_model=model,
        reward_funcs=reward_funcs,
        tokenizer=tokenizer,
        batch_size=args.batch_size,
        num_batches=args.val_batches,