Goekdeniz-Guelmez 2025-02-12 11:07:53 +01:00
parent c42e858d7e
commit e33d9d509b
3 changed files with 70 additions and 111 deletions

View File

@ -2,9 +2,8 @@ import itertools
import json
import types
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Tuple
from .utils import GRPOExample
from transformers import PreTrainedTokenizer
@ -12,7 +11,7 @@ class GRPODataset:
"""
Dataset wrapper for GRPO training data.
Each example should have a 'prompt' and 'answer' field.
Returns data as GRPOExample instances.
Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
"""
def __init__(
self,
@ -23,40 +22,33 @@ class GRPODataset:
use_chat_template: bool = False,
use_prompt: bool = False
):
self._data: List[GRPOExample] = []
self._data = []
for item in data:
prompt_str = str(item[prompt_key])
answer_str = str(item[answer_key])
if use_chat_template:
prompt_tokens = tokenizer.apply_chat_template(
[
{'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
{'role': 'user', 'content': prompt_str}
],
)
answer_tokens = tokenizer.encode(answer_str)
else:
if use_prompt:
prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str} Assistant: """)
The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
User: {prompt_str} Assistant: """)
else:
prompt_tokens = tokenizer.encode(prompt_str)
answer_tokens = tokenizer.encode(answer_str)
self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))
self._data.append(GRPOExample(
prompt_tokens=prompt_tokens,
answer_tokens=answer_tokens,
prompt_text=prompt_str,
answer_text=answer_str
))
def __getitem__(self, idx: int) -> GRPOExample:
"""Returns a GRPOExample instance."""
def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
"""Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
return self._data[idx]
def __len__(self) -> int:
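A minimal usage sketch (not part of this diff): with the revert above, a GRPODataset item is a plain 4-tuple rather than a GRPOExample, so callers unpack it positionally. The literal values below are illustrative only.

# Illustrative shape of one dataset item after this change.
example = ([101, 2023, 2003], [102, 2048], "What is 1+1?", "2")
prompt_tokens, answer_tokens, prompt_str, answer_str = example
assert isinstance(prompt_tokens, list) and isinstance(prompt_str, str)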
@ -318,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
else:
print(f"Loading Hugging Face dataset {args.data}.")
train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
train, valid, test = load_hf_dataset(args.data, tokenizer, args)
if args.train and len(train) == 0:
raise ValueError(

View File

@ -1,18 +1,18 @@
# Copyright © 2024 Apple Inc.
import time
from typing import List, Optional, Callable
from dataclasses import dataclass, field
from typing import List, Iterator, Optional
from pathlib import Path
import time
import re
from mlx.utils import tree_flatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from .utils import GRPOBatch, GRPOExample
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@ -37,6 +37,9 @@ class GRPOTrainingArgs(TrainingArgs):
)
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
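# Hedged example (not from the repository): a reward function matching the
# RewardFunctions alias above. Argument names are illustrative, since the
# keyword names used at the call site are not shown in this hunk.
def toy_brevity_reward(prompts: List[str], completions: List[str], answer: List[str]) -> List[float]:
    # Placeholder scoring rule: favor shorter completions.
    return [1.0 / (1.0 + len(c)) for c in completions]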
def r1_extract_xml_answer(text: str) -> str:
"""Extracts the answer from an XML formatted text string."""
try:
@ -180,10 +183,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
def grpo_loss(
model: nn.Module,
ref_model: Optional[nn.Module],
model,
ref_model,
tokenizer,
batch=GRPOBatch,
batch,
reward_funcs=None,
beta=0.1,
group_size=4,
@ -191,18 +194,14 @@ def grpo_loss(
max_tokens=64,
temperature=1.0
):
prompts_tokens = batch.prompt_tokens
answers_tokens = batch.answer_tokens
prompts_text = batch.prompt_texts
answers_text = batch.answer_texts
batch_size = len(prompts_tokens)
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
# Generation logic remains the same
all_completions = []
all_completion_texts = []
for i in range(0, batch_size, batch_size):
batch_prompts = prompts_tokens[i:i+batch_size]
batch_prompts = prompt_tokens[i:i+batch_size]
for prompt in batch_prompts:
prompt_tensor = mx.array(prompt)
for _ in range(group_size):
@ -212,8 +211,6 @@ def grpo_loss(
completion_text = tokenizer.decode(completion_ids.tolist())
all_completions.append(completion_ids)
all_completion_texts.append(completion_text)
# Clear completion tensors
mx.eval(completion_ids)
del completion_ids
except Exception as e:
@ -222,12 +219,11 @@ def grpo_loss(
mx.metal.clear_cache()
# Prepare inputs
expanded_answers = []
expanded_prompts = []
for i in range(batch_size):
expanded_answers.extend([answers_text[i]] * group_size)
expanded_prompts.extend([prompts_text[i]] * group_size)
expanded_answers.extend([answer_text[i]] * group_size)
expanded_prompts.extend([prompt_text[i]] * group_size)
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
@ -260,6 +256,8 @@ def grpo_loss(
ref_token_log_probs = token_log_probs
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
mx.eval(ref_token_log_probs)
mx.metal.clear_cache()
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
@ -275,7 +273,7 @@ def grpo_loss(
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
# Calculate rewards and advantages
# Rewards and advantages
rewards = mx.zeros((len(all_completions),))
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
@ -288,7 +286,7 @@ def grpo_loss(
if len(reward_funcs) > 1:
rewards /= len(reward_funcs)
# Reshape rewards and compute advantages following GRPO formula
# Reshape rewards and compute advantages
rewards_reshaped = rewards.reshape(batch_size, group_size)
mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
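A minimal sketch (not the file's exact code) of the group-relative advantage normalization that these per-group means and stds feed into, assuming a small epsilon guard like the one passed to evaluate_grpo below:

import mlx.core as mx

def group_advantages(rewards: mx.array, batch_size: int, group_size: int, epsilon: float = 1e-4) -> mx.array:
    # One reward per sampled completion: batch_size * group_size entries.
    grouped = rewards.reshape(batch_size, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)  # per-prompt mean
    std = mx.std(grouped, axis=1, keepdims=True)    # per-prompt std
    # Normalize each completion's reward within its own prompt group.
    return ((grouped - mean) / (std + epsilon)).reshape(-1)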
@ -303,7 +301,7 @@ def grpo_loss(
# Compute policy ratio
policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
# Compute per-token loss following GRPO formula
# Compute per-token loss
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
# Average over tokens
@ -339,58 +337,50 @@ def grpo_loss(
return loss, sequence_lengths.sum(), metrics
def iterate_grpo_batches(
dataset: List[GRPOExample],
batch_size: int,
max_seq_length: int,
train: bool = False,
) -> Iterator[GRPOBatch]:
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
def length_key(i):
return len(dataset[i][0]) + len(dataset[i][1])
idx = sorted(range(len(dataset)), key=length_key)
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
# Get MLX distributed setup
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Sort by combined length for efficient batching
def length_key(example: GRPOExample) -> int:
return len(example.prompt_tokens) + len(example.answer_tokens)
sorted_dataset = sorted(dataset, key=length_key)
# Create batch indices
num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
batch_starts = range(0, num_complete_batches * batch_size, batch_size)
def batch_index_generator():
for i in range(0, len(idx) - batch_size + 1, batch_size):
yield idx[i : i + batch_size : step]
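# Illustrative note: with batch_size=4 and step=2 (two distributed workers),
# this slice keeps every step-th index of the 4-index window, so each rank
# assembles a local batch of batch_size // step examples per iteration.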
while True:
# Shuffle batch start indices
shuffled_starts = np.random.permutation(batch_starts)
indices = (
np.random.permutation(list(batch_index_generator())) if train
else batch_index_generator()
)
for start_idx in shuffled_starts:
# Account for distributed setup by taking every step-th example
batch_idx = list(range(start_idx, start_idx + batch_size, step))
current_batch = [sorted_dataset[j] for j in batch_idx]
for batch_idx in indices:
current_batch = [dataset[j] for j in batch_idx]
# Create batch using dataclass attributes
batch = GRPOBatch(
prompt_tokens=[ex.prompt_tokens for ex in current_batch],
answer_tokens=[ex.answer_tokens for ex in current_batch],
prompt_texts=[ex.prompt_text for ex in current_batch],
answer_texts=[ex.answer_text for ex in current_batch]
)
prompts_tokens = [item[0] for item in current_batch]
answers_tokens = [item[1] for item in current_batch]
prompts_text = [item[2] for item in current_batch]
answers_text = [item[3] for item in current_batch]
# Check sequence lengths
if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
if any(len(p) > max_seq_length for p in prompts_tokens):
print(
f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
"Long prompts will be truncated."
)
yield batch
yield prompts_tokens, answers_tokens, prompts_text, answers_text
if not train:
break
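A hedged usage sketch (toy values, not from the diff) of consuming the tuple-style batches yielded above, assuming iterate_grpo_batches is imported from this trainer module:

# Toy dataset of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples.
toy_dataset = [([1, 2, 3], [4, 5], "prompt", "answer")] * 8
for prompts_tokens, answers_tokens, prompts_text, answers_text in iterate_grpo_batches(
    toy_dataset, batch_size=4, max_seq_length=512, train=False
):
    assert len(prompts_tokens) == 4  # one entry per example in the batch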
@ -407,7 +397,7 @@ def evaluate_grpo(
epsilon: float,
group_size: int,
max_seq_length,
reward_funcs = None,
reward_funcs: Optional[List[RewardFunctions]] = None,
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
):
@ -418,12 +408,10 @@ def evaluate_grpo(
"""
all_losses = 0
ntokens = 0
all_metrics = None # Initialize metrics dictionary
all_metrics = None
# Create iterator for batches
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
# Iterate through batches
for _, batch in zip(
index_iterator,
iterate_batches(
@ -432,7 +420,6 @@ def evaluate_grpo(
max_seq_length=max_seq_length,
),
):
# Calculate loss for current batch
losses, toks, metrics = loss_fn(
model=model,
tokenizer=tokenizer,
@ -444,18 +431,15 @@ def evaluate_grpo(
ref_model=ref_model
)
# Accumulate losses and tokens
all_losses += losses * toks
ntokens += toks
# Accumulate metrics
if all_metrics is None:
all_metrics = {k: v * toks for k, v in metrics.items()}
else:
for k, v in metrics.items():
all_metrics[k] += v * toks
# Evaluate accumulated values
mx.eval(all_losses, ntokens)
# Aggregate across distributed workers
@ -475,9 +459,9 @@ def train_grpo(
ref_model: Optional[nn.Module],
tokenizer,
optimizer,
train_dataset: List[GRPOExample],
val_dataset: List[GRPOExample],
reward_funcs = [
train_dataset,
val_dataset,
reward_funcs: Optional[List[RewardFunctions]] = [
r1_accuracy_reward_func,
r1_int_reward_func,
r1_strict_format_reward_func,

View File

@ -2,8 +2,7 @@
import json
import types
from pathlib import Path
from typing import Dict, List
from dataclasses import dataclass
from typing import Dict
import mlx.core as mx
import mlx.nn as nn
@ -276,19 +275,3 @@ def print_trainable_parameters(model):
f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
f"({trainable_p:.3f}M/{total_p:.3f}M)"
)
@dataclass
class GRPOExample:
"""Single example for GRPO training/inference."""
prompt_tokens: List[int]
answer_tokens: List[int]
prompt_text: str
answer_text: str
@dataclass
class GRPOBatch:
"""A batch of GRPO examples."""
prompt_tokens: List[List[int]]
answer_tokens: List[List[int]]
prompt_texts: List[str]
answer_texts: List[str]
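For reference, a minimal sketch (not repository code) of the plain-tuple batch layout that stands in for the deleted GRPOBatch after this commit:

from typing import List, Tuple

# A batch is now a 4-tuple of parallel lists rather than a dataclass:
GRPOBatchTuple = Tuple[
    List[List[int]],  # prompt token ids, one list per example
    List[List[int]],  # answer token ids
    List[str],        # prompt texts
    List[str],        # answer texts
]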