Goekdeniz-Guelmez 2025-02-12 11:07:53 +01:00
parent c42e858d7e
commit e33d9d509b
3 changed files with 70 additions and 111 deletions

View File

@@ -2,9 +2,8 @@ import itertools
 import json
 import types
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Tuple
+from typing import Any, Dict, List, Tuple

-from .utils import GRPOExample
 from transformers import PreTrainedTokenizer
@@ -12,7 +11,7 @@ class GRPODataset:
     """
     Dataset wrapper for GRPO training data.
     Each example should have a 'prompt' and 'answer' field.
-    Returns data as GRPOExample instances.
+    Returns data in (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple format.
     """
     def __init__(
         self,
@@ -23,40 +22,33 @@ class GRPODataset:
         use_chat_template: bool = False,
         use_prompt: bool = False
     ):
-        self._data: List[GRPOExample] = []
+        self._data = []
         for item in data:
             prompt_str = str(item[prompt_key])
             answer_str = str(item[answer_key])
             if use_chat_template:
                 prompt_tokens = tokenizer.apply_chat_template(
                     [
                         {'role': 'system', 'content': """A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>."""},
                         {'role': 'user', 'content': prompt_str}
                     ],
                 )
                 answer_tokens = tokenizer.encode(answer_str)
             else:
                 if use_prompt:
                     prompt_tokens = tokenizer.encode(f"""A conversation between User and Assistant. The user asks a question, and the Assistant solves it.
 The assistantfirst thinks about the reasoning process in the mind and then provides the user with the answer.
 The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>.
 User: {prompt_str} Assistant: """)
                 else:
                     prompt_tokens = tokenizer.encode(prompt_str)
                 answer_tokens = tokenizer.encode(answer_str)
-            self._data.append(GRPOExample(
-                prompt_tokens=prompt_tokens,
-                answer_tokens=answer_tokens,
-                prompt_text=prompt_str,
-                answer_text=answer_str
-            ))
+            self._data.append((prompt_tokens, answer_tokens, prompt_str, answer_str))

-    def __getitem__(self, idx: int) -> GRPOExample:
-        """Returns a GRPOExample instance."""
+    def __getitem__(self, idx: int) -> Tuple[List[int], List[int], str, str]:
+        """Returns a (prompt_tokens, answer_tokens, prompt_str, answer_str) tuple."""
         return self._data[idx]

     def __len__(self) -> int:
@@ -318,7 +310,7 @@ def load_dataset(args, tokenizer: PreTrainedTokenizer):
         train, valid, test = load_local_dataset(args, data_path, tokenizer, args)
     else:
         print(f"Loading Hugging Face dataset {args.data}.")
-        train, valid, test = load_hf_dataset(args, args.data, tokenizer, args)
+        train, valid, test = load_hf_dataset(args.data, tokenizer, args)
     if args.train and len(train) == 0:
         raise ValueError(
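With this change, `GRPODataset.__getitem__` returns a plain 4-tuple instead of a `GRPOExample`. A minimal sketch of consuming the new item format; the `records` and `tok` objects are placeholders, and the constructor argument order is assumed from the parameters visible in the hunks above:

```python
# Sketch: consuming the tuple format returned by GRPODataset.
# `records` and `tok` stand in for a loaded dataset and tokenizer (not part of the commit).
dataset = GRPODataset(records, tok, prompt_key="prompt", answer_key="answer")

prompt_tokens, answer_tokens, prompt_str, answer_str = dataset[0]
print(len(prompt_tokens), len(answer_tokens))  # token counts used for length-sorted batching
print(prompt_str, "->", answer_str)            # raw strings kept around for the reward functions
```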

View File

@@ -1,18 +1,18 @@
 # Copyright © 2024 Apple Inc.

-import time
+from typing import List, Optional, Callable
 from dataclasses import dataclass, field
-from typing import List, Iterator, Optional
 from pathlib import Path
+import time
 import re

-from mlx.utils import tree_flatten
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.utils import tree_flatten

-from .utils import GRPOBatch, GRPOExample
-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients

 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -37,6 +37,9 @@ class GRPOTrainingArgs(TrainingArgs):
     )

+RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]

 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
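The new `RewardFunctions` alias pins down the expected reward signature: lists of prompts, completions, and reference answers in, one float per completion out. A hedged sketch of a custom reward that satisfies the alias; the function name, parameter names, and scoring rule are illustrative, not part of the commit:

```python
from typing import List

def contains_answer_reward(prompts: List[str], completions: List[str], answer: List[str]) -> List[float]:
    """Illustrative reward: 1.0 if the reference answer appears in the completion, else 0.0."""
    return [
        1.0 if ref.strip() and ref.strip() in completion else 0.0
        for completion, ref in zip(completions, answer)
    ]

# Matches RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
scores = contains_answer_reward(["2+2?"], ["<answer> 4 </answer>"], ["4"])  # -> [1.0]
```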
@@ -180,10 +183,10 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):

 def grpo_loss(
-    model: nn.Module,
-    ref_model: Optional[nn.Module],
+    model,
+    ref_model,
     tokenizer,
-    batch=GRPOBatch,
+    batch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
@@ -191,18 +194,14 @@ def grpo_loss(
     max_tokens=64,
     temperature=1.0
 ):
-    prompts_tokens = batch.prompt_tokens
-    answers_tokens = batch.answer_tokens
-    prompts_text = batch.prompt_texts
-    answers_text = batch.answer_texts
-    batch_size = len(prompts_tokens)
+    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
+    batch_size = len(prompt_tokens)

-    # Generation logic remains the same
     all_completions = []
     all_completion_texts = []

     for i in range(0, batch_size, batch_size):
-        batch_prompts = prompts_tokens[i:i+batch_size]
+        batch_prompts = prompt_tokens[i:i+batch_size]
         for prompt in batch_prompts:
             prompt_tensor = mx.array(prompt)
             for _ in range(group_size):
@@ -212,8 +211,6 @@ def grpo_loss(
                         completion_text = tokenizer.decode(completion_ids.tolist())
                         all_completions.append(completion_ids)
                         all_completion_texts.append(completion_text)
-
-                        # Clear completion tensors
                         mx.eval(completion_ids)
                         del completion_ids
                 except Exception as e:
@@ -222,12 +219,11 @@ def grpo_loss(
     mx.metal.clear_cache()

-    # Prepare inputs
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
-        expanded_answers.extend([answers_text[i]] * group_size)
-        expanded_prompts.extend([prompts_text[i]] * group_size)
+        expanded_answers.extend([answer_text[i]] * group_size)
+        expanded_prompts.extend([prompt_text[i]] * group_size)

     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
@@ -260,6 +256,8 @@ def grpo_loss(
         ref_token_log_probs = token_log_probs
     else:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+        mx.eval(ref_token_log_probs)
+        mx.metal.clear_cache()

     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
@@ -275,7 +273,7 @@ def grpo_loss(
     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)

-    # Calculate rewards and advantages
+    # Rewards and advantages
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
@@ -288,7 +286,7 @@ def grpo_loss(
     if len(reward_funcs) > 1:
         rewards /= len(reward_funcs)

-    # Reshape rewards and compute advantages following GRPO formula
+    # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
     mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
     std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
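For context, this block computes GRPO's group-relative advantages: each completion's reward is normalized by the mean and standard deviation of its own prompt group. A standalone sketch of the same computation; the small `eps` guard against zero variance is an assumption here, since the line that uses it falls outside the visible hunk:

```python
import mlx.core as mx

def group_advantages(rewards: mx.array, batch_size: int, group_size: int, eps: float = 1e-4) -> mx.array:
    """Group-relative advantage: (reward - group mean) / (group std + eps)."""
    grouped = rewards.reshape(batch_size, group_size)
    mean = mx.mean(grouped, axis=1, keepdims=True)
    std = mx.std(grouped, axis=1, keepdims=True)
    return ((grouped - mean) / (std + eps)).reshape(-1)

# Two prompts with three sampled completions each; the second group has zero variance.
adv = group_advantages(mx.array([1.0, 0.0, 0.5, 2.0, 2.0, 2.0]), batch_size=2, group_size=3)
```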
@@ -303,7 +301,7 @@ def grpo_loss(
     # Compute policy ratio
     policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))

-    # Compute per-token loss following GRPO formula
+    # Compute per-token loss
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

     # Average over tokens
@@ -339,58 +337,50 @@ def grpo_loss(
     return loss, sequence_lengths.sum(), metrics

-def iterate_grpo_batches(
-    dataset: List[GRPOExample],
-    batch_size: int,
-    max_seq_length: int,
-    train: bool = False,
-) -> Iterator[GRPOBatch]:
+def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
+    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
+        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+    def length_key(i):
+        return len(dataset[i][0]) + len(dataset[i][1])
+    idx = sorted(range(len(dataset)), key=length_key)

     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
             f"examples but only has {len(dataset)}."
         )

-    # Get MLX distributed setup
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")

-    # Sort by combined length for efficient batching
-    def length_key(example: GRPOExample) -> int:
-        return len(example.prompt_tokens) + len(example.answer_tokens)
-    sorted_dataset = sorted(dataset, key=length_key)
-    # Create batch indices
-    num_complete_batches = (len(dataset) - batch_size + 1) // batch_size
-    batch_starts = range(0, num_complete_batches * batch_size, batch_size)
+    def batch_index_generator():
+        for i in range(0, len(idx) - batch_size + 1, batch_size):
+            yield idx[i : i + batch_size : step]

     while True:
-        # Shuffle batch start indices
-        shuffled_starts = np.random.permutation(batch_starts)
-        for start_idx in shuffled_starts:
-            # Account for distributed setup by taking every step-th example
-            batch_idx = list(range(start_idx, start_idx + batch_size, step))
-            current_batch = [sorted_dataset[j] for j in batch_idx]
-            # Create batch using dataclass attributes
-            batch = GRPOBatch(
-                prompt_tokens=[ex.prompt_tokens for ex in current_batch],
-                answer_tokens=[ex.answer_tokens for ex in current_batch],
-                prompt_texts=[ex.prompt_text for ex in current_batch],
-                answer_texts=[ex.answer_text for ex in current_batch]
-            )
-            # Check sequence lengths
-            if any(len(tokens) > max_seq_length for tokens in batch.prompt_tokens):
+        indices = (
+            np.random.permutation(list(batch_index_generator())) if train
+            else batch_index_generator()
+        )
+        for batch_idx in indices:
+            current_batch = [dataset[j] for j in batch_idx]
+            prompts_tokens = [item[0] for item in current_batch]
+            answers_tokens = [item[1] for item in current_batch]
+            prompts_text = [item[2] for item in current_batch]
+            answers_text = [item[3] for item in current_batch]
+            if any(len(p) > max_seq_length for p in prompts_tokens):
                 print(
                     f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                     "Long prompts will be truncated."
                 )
-            yield batch
+            yield prompts_tokens, answers_tokens, prompts_text, answers_text

             if not train:
                 break
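A minimal sketch of consuming the rewritten iterator, assuming the dataset holds the 4-tuples produced by `GRPODataset`; the toy token lists below are placeholders:

```python
# Toy 4-tuple entries standing in for GRPODataset items.
toy_dataset = [
    ([1, 2, 3], [4, 5], "What is 2+2?", "4"),
    ([1, 2], [6], "What is 1+1?", "2"),
    ([7, 8, 9, 10], [11, 12], "What is 3*3?", "9"),
    ([13], [14], "What is 5-2?", "3"),
]

batches = iterate_grpo_batches(toy_dataset, batch_size=2, max_seq_length=512, train=False)
prompts_tokens, answers_tokens, prompts_text, answers_text = next(batches)
assert len(prompts_tokens) == len(answers_text) == 2  # one length-sorted batch of 2 examples
```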
@@ -407,7 +397,7 @@ def evaluate_grpo(
     epsilon: float,
     group_size: int,
     max_seq_length,
-    reward_funcs = None,
+    reward_funcs: Optional[List[RewardFunctions]] = None,
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches
 ):
@@ -418,12 +408,10 @@ def evaluate_grpo(
     """
     all_losses = 0
     ntokens = 0
-    all_metrics = None # Initialize metrics dictionary
+    all_metrics = None

-    # Create iterator for batches
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

-    # Iterate through batches
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -432,7 +420,6 @@ def evaluate_grpo(
             max_seq_length=max_seq_length,
         ),
     ):
-        # Calculate loss for current batch
         losses, toks, metrics = loss_fn(
             model=model,
             tokenizer=tokenizer,
@@ -444,18 +431,15 @@ def evaluate_grpo(
             ref_model=ref_model
         )

-        # Accumulate losses and tokens
         all_losses += losses * toks
         ntokens += toks

-        # Accumulate metrics
         if all_metrics is None:
             all_metrics = {k: v * toks for k, v in metrics.items()}
         else:
             for k, v in metrics.items():
                 all_metrics[k] += v * toks

-        # Evaluate accumulated values
         mx.eval(all_losses, ntokens)

     # Aggregate across distributed workers
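The loop above accumulates token-weighted sums, so the final loss and metrics come out as per-token averages; the "Aggregate across distributed workers" step itself is not shown in this hunk. A hedged sketch of that finalization, assuming an `mx.distributed.all_sum` over the accumulators rather than the exact code from the file:

```python
import mlx.core as mx

def finalize_eval(all_losses: mx.array, ntokens: mx.array, all_metrics: dict):
    """Token-weighted averages, summed across workers (assumed shape of the unshown code)."""
    ntokens = mx.distributed.all_sum(ntokens)
    avg_loss = (mx.distributed.all_sum(all_losses) / ntokens).item()
    avg_metrics = {k: (mx.distributed.all_sum(v) / ntokens).item() for k, v in all_metrics.items()}
    return avg_loss, avg_metrics

# Single-worker example: two batches of 10 and 30 tokens with losses 0.5 and 0.3.
loss, metrics = finalize_eval(mx.array(0.5 * 10 + 0.3 * 30), mx.array(40), {"kl": mx.array(0.1 * 10 + 0.2 * 30)})
# loss == 0.35, metrics["kl"] == 0.175
```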
@@ -475,9 +459,9 @@ def train_grpo(
     ref_model: Optional[nn.Module],
     tokenizer,
     optimizer,
-    train_dataset: List[GRPOExample],
-    val_dataset: List[GRPOExample],
-    reward_funcs = [
+    train_dataset,
+    val_dataset,
+    reward_funcs: Optional[List[RewardFunctions]] = [
         r1_accuracy_reward_func,
         r1_int_reward_func,
         r1_strict_format_reward_func,

View File

@@ -2,8 +2,7 @@
 import json
 import types
 from pathlib import Path
-from typing import Dict, List
-from dataclasses import dataclass
+from typing import Dict

 import mlx.core as mx
 import mlx.nn as nn
@@ -276,19 +275,3 @@ def print_trainable_parameters(model):
         f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% "
         f"({trainable_p:.3f}M/{total_p:.3f}M)"
     )
-
-@dataclass
-class GRPOExample:
-    """Single example for GRPO training/inference."""
-    prompt_tokens: List[int]
-    answer_tokens: List[int]
-    prompt_text: str
-    answer_text: str
-
-@dataclass
-class GRPOBatch:
-    """A batch of GRPO examples."""
-    prompt_tokens: List[List[int]]
-    answer_tokens: List[List[int]]
-    prompt_texts: List[str]
-    answer_texts: List[str]