# Copyright © 2024 Apple Inc.
import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from mlx_lm import generate

from .trainer import (
    TrainingArgs,
    TrainingCallback,
    average_gradients,
    grad_checkpoint,
    iterate_batches,
)
@dataclass
class GRPOTrainingArgs(TrainingArgs):
    """Training arguments for GRPO, extending the base ``TrainingArgs``."""

    # Number of completions generated per prompt; rewards are normalized
    # within each such group.
    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    # Weight of the KL penalty towards the reference model.
    beta: float = field(
        default=0.1, metadata={"help": "KL penalty coefficient."}
    )
    # Added to the per-group reward standard deviation before dividing, so a
    # group whose rewards are all equal does not produce 0/0.
    epsilon: float = field(
        default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
    )
    # Optional path to reference-model weights. Not read anywhere in this
    # module -- presumably consumed by whoever builds ``ref_model``; confirm.
    reference_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        },
    )
def r1_extract_xml_answer(text: str) -> str:
    """Extract the content of the last ``<answer>...</answer>`` section.

    The text after the last ``<answer>`` tag, up to the next ``</answer>``
    tag, is returned with surrounding whitespace stripped. If the tags are
    absent, ``str.split`` leaves the string intact, so the stripped input is
    returned unchanged.

    Args:
        text: A model completion.

    Returns:
        The extracted answer, or ``""`` when ``text`` is not a string.
    """
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
    except (AttributeError, TypeError):
        # Non-string input (e.g. None or bytes): report and fall back to "".
        print("[extract_xml_answer] Failed to extract answer from: ", text)
        return ""
    return answer.strip()
def r1_accuracy_reward_func(prompts, completions, answer, **kwargs) -> list[float]:
    """Calculates reward based on accuracy of extracted answers.

    Each completion's ``<answer>`` content is extracted and compared, by exact
    string equality, with the corresponding expected answer.

    Args:
        prompts: List of input prompts (not used in the computation)
        completions: List of completion strings
        answer: Expected answer -- either a single string applied to every
            completion, or a list with one expected answer per completion
        **kwargs: Additional arguments (ignored)

    Returns:
        list[float]: 2.0 for each exact match, otherwise 0.0

    Raises:
        ValueError: If a list of answers does not have one entry per completion.
    """
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    # A bare string would otherwise be zipped character by character.
    if isinstance(answer, str):
        answers = [answer] * len(extracted_responses)
    else:
        answers = list(answer)
    # zip() would silently truncate and misalign rewards on a length mismatch.
    if len(answers) != len(extracted_responses):
        raise ValueError(
            f"Expected {len(extracted_responses)} answers (one per completion), "
            f"got {len(answers)}."
        )
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answers)]
def r1_int_reward_func(completions, **kwargs) -> list[float]:
    """Give a small bonus to completions whose extracted answer is all digits.

    Args:
        completions: Iterable of completion strings.
        **kwargs: Ignored.

    Returns:
        list[float]: ``0.5`` for each completion whose extracted answer
        satisfies ``str.isdigit()``, otherwise ``0.0``.
    """
    extracted = map(r1_extract_xml_answer, completions)
    return [0.5 if candidate.isdigit() else 0.0 for candidate in extracted]
def r1_strict_format_reward_func(completions, **kwargs) -> list[float]:
    """Rewards completions with strict XML format.

    A completion must start with ``<think>`` and follow the layout
    ``<think>\\n...\\n</think>\\n<answer>\\n...\\n</answer>\\n`` exactly.

    Args:
        completions: List of completion strings
        **kwargs: Additional arguments (ignored)

    Returns:
        list[float]: 0.5 for each completion matching the layout, else 0.0
    """
    pattern = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"
    # re.DOTALL lets ``.*?`` span multi-line reasoning/answer bodies; without
    # it any completion with more than one line of reasoning fails to match.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in completions]
    return [0.5 if match else 0.0 for match in matches]
def r1_soft_format_reward_func(completions, **kwargs) -> list[float]:
    """Rewards completions with flexible XML format.

    A completion must begin with a ``<think>...</think>`` section followed
    (after optional whitespace) by an ``<answer>...</answer>`` section; the
    exact newlines are not enforced.

    Args:
        completions: List of completion strings
        **kwargs: Additional arguments (ignored)

    Returns:
        list[float]: 0.5 for each completion matching the layout, else 0.0
    """
    pattern = r"<think>.*?</think>\s*<answer>.*?</answer>"
    # re.DOTALL so the tag contents may span several lines.
    matches = [re.match(pattern, r, flags=re.DOTALL) for r in completions]
    return [0.5 if match else 0.0 for match in matches]
def r1_count_xml(text: str) -> float:
    """Calculates score based on XML formatting.

    Each of the four newline-delimited tags adds 0.125 when it occurs exactly
    once, for a maximum of 0.5. Text trailing the closing answer tag is
    penalized at 0.001 per character.

    Args:
        text: A single completion string

    Returns:
        float: Score based on XML tag presence and formatting
    """
    count = 0.0
    if text.count("<think>\n") == 1:
        count += 0.125
    if text.count("\n</think>\n") == 1:
        count += 0.125
    if text.count("\n<answer>\n") == 1:
        count += 0.125
        # Penalize characters after "\n</answer>\n".
        count -= len(text.split("\n</answer>\n")[-1]) * 0.001
    if text.count("\n</answer>") == 1:
        count += 0.125
        # Penalize characters after "\n</answer>"; the "- 1" forgives the one
        # trailing newline expected there.
        count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
    return count
def _gather_token_log_probs(logits, targets):
    """Return the log-probability the logits assign to each next-token target.

    ``logits`` covers the full input sequence; its last position is dropped
    because it has no next-token target. ``targets`` is ``inputs[:, 1:]``.
    """
    log_probs = nn.log_softmax(logits[:, :-1, :].astype(mx.float32), axis=-1)
    return mx.take_along_axis(log_probs, targets[..., None], axis=-1).squeeze(-1)


def grpo_loss(
    model,
    tokenizer,
    prompts,
    reward_funcs=None,
    beta=0.1,
    group_size=4,
    epsilon=1e-4,
    ref_model=None,
    answers=None,
):
    """
    Calculates the GRPO loss with support for multiple reward functions.

    For every prompt, ``group_size`` completions are generated. Each
    completion is scored by every reward function; the per-function rewards
    are averaged and then normalized within the group of completions for the
    same prompt to obtain advantages.

    Args:
        model: The model to optimize
        tokenizer: Tokenizer for processing inputs (called with
            ``return_tensors="np", padding=True``)
        prompts: List of input prompts
        reward_funcs: Non-empty list of reward functions. Each is called as
            ``reward_func(prompts=..., completions=...)`` -- plus
            ``answer=...`` when ``answers`` is given -- with one entry per
            completion, and must return one number per completion.
        beta: KL penalty coefficient
        group_size: Number of completions per prompt
        epsilon: Small constant for numerical stability
        ref_model: Optional reference model for KL divergence. If None, the
            current model's own (gradient-stopped) log-probabilities are used.
        answers: Optional reference answers, one per prompt. They are
            repeated to one per completion and passed as ``answer=``.

    Returns:
        tuple: (loss, total_sequence_length, metrics_dict)

    Raises:
        ValueError: If ``reward_funcs`` is empty/None or ``answers`` does not
            have one entry per prompt.
    """
    if not reward_funcs:
        raise ValueError("grpo_loss requires at least one reward function.")
    batch_size = len(prompts)
    if answers is not None and len(answers) != batch_size:
        raise ValueError(
            f"Expected one answer per prompt ({batch_size}), got {len(answers)}."
        )

    # Completions are stored grouped by prompt:
    # [p0_c0, ..., p0_c{G-1}, p1_c0, ...].
    # NOTE(review): with deterministic (greedy) generation every completion in
    # a group is identical and all advantages are zero -- confirm sampling.
    all_completions = [
        generate(model, tokenizer, prompt)
        for prompt in prompts
        for _ in range(group_size)
    ]
    # Repeat prompts in the same grouped order. ``prompts * group_size`` would
    # cycle through the prompts and pair completions with the wrong prompt.
    expanded_prompts = [prompt for prompt in prompts for _ in range(group_size)]

    tokenized_inputs = tokenizer(
        [p + c for p, c in zip(expanded_prompts, all_completions)],
        return_tensors="np",
        padding=True,
    )
    inputs = mx.array(tokenized_inputs["input_ids"])
    attention_mask = mx.array(tokenized_inputs["attention_mask"])
    targets = inputs[:, 1:]

    token_log_probs = _gather_token_log_probs(model(inputs), targets)
    if ref_model is not None:
        ref_token_log_probs = _gather_token_log_probs(ref_model(inputs), targets)
    else:
        # Same model: reuse this forward pass instead of running it twice.
        ref_token_log_probs = token_log_probs
    # The reference is a fixed target; no gradient should flow through it.
    ref_token_log_probs = mx.stop_gradient(ref_token_log_probs)

    # Per-token KL estimate exp(r) - r - 1, with r = log p_ref - log p_policy.
    log_ratio = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(log_ratio) - log_ratio - 1

    # Score every completion once per reward function (reused for metrics).
    reward_kwargs = {"prompts": expanded_prompts, "completions": all_completions}
    if answers is not None:
        reward_kwargs["answer"] = [a for a in answers for _ in range(group_size)]
    per_func_rewards = [
        mx.array(reward_func(**reward_kwargs), dtype=mx.float32)
        for reward_func in reward_funcs
    ]
    # Average over reward functions (equal to the raw reward for one function).
    rewards = mx.stack(per_func_rewards).mean(axis=0)

    # Group-normalized advantages.
    grouped_rewards = rewards.reshape(batch_size, group_size)
    group_mean = mx.mean(grouped_rewards, axis=1, keepdims=True)
    group_std = mx.std(grouped_rewards, axis=1, keepdims=True)
    advantages = ((grouped_rewards - group_mean) / (group_std + epsilon)).reshape(-1)

    # True at every non-padding target position. Using the attention mask
    # directly works for both left and right padding.
    # NOTE(review): prompt tokens are not masked out, so they also receive the
    # policy-gradient and KL terms -- confirm whether only completion tokens
    # should count.
    length_mask = attention_mask[:, 1:] > 0

    # exp(x - stop_gradient(x)) equals 1 but has gradient dx, so each token
    # contributes advantage * d(log p): the policy-gradient term.
    ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
    per_token_loss = -(ratio * advantages[:, None] - beta * kl_div)

    # Average per sequence, then over the batch.
    sequence_lengths = length_mask.sum(axis=1)
    # Guard rows with no valid targets against 0/0 = NaN.
    denominators = mx.maximum(sequence_lengths, 1)
    loss = ((per_token_loss * length_mask).sum(axis=1) / denominators).mean()
    mean_kl = ((kl_div * length_mask).sum(axis=1) / denominators).mean()

    reward_metrics = {}
    for i, func_rewards in enumerate(per_func_rewards):
        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(grouped_rewards),
        'grouped_rewards_std': mx.std(grouped_rewards),
        'kl': mean_kl,
        **reward_metrics,
    }
    return loss, sequence_lengths.sum(), metrics
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """
    Creates batches from prompt-answer pairs for GRPO training.

    Examples are sorted by the combined character length of prompt and answer
    and cut into consecutive batches of ``batch_size``; a trailing remainder
    smaller than ``batch_size`` is dropped. In distributed mode, each worker
    takes every ``world_size``-th example of a batch, starting at its own
    rank, so the workers see disjoint examples.

    Args:
        dataset: List of (prompt, answer) pairs
        tokenizer: Tokenizer used to measure prompt length in tokens
        batch_size: Size of each batch, summed over all workers
        max_seq_length: Prompt length (in tokens) above which a warning is printed
        train: If True, shuffle batch order and repeat forever; otherwise
            yield every batch once

    Yields:
        List of prompts for the current batch (answers are not yielded)

    Raises:
        ValueError: If the dataset is malformed or smaller than ``batch_size``,
            or ``batch_size`` is not divisible by the number of workers.
    """
    if not dataset or not isinstance(dataset[0], (tuple, list)) or len(dataset[0]) != 2:
        raise ValueError("Dataset must be a list of (prompt, answer) pairs")

    idx = sorted(
        range(len(dataset)),
        key=lambda i: len(dataset[i][0]) + len(dataset[i][1]),
    )
    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    world = mx.distributed.init()
    step = world.size()
    offset = world.rank()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    # Start each worker at its own rank and stride by the world size. Without
    # the offset every worker would receive exactly the same examples.
    batch_idx = [
        idx[i + offset : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            prompts = [dataset[j][0] for j in batch_idx[i]]
            # max_seq_length is a token budget, so measure tokens, not characters.
            if any(len(tokenizer.encode(p)) > max_seq_length for p in prompts):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "They are passed through untruncated."
                )
            yield prompts
        if not train:
            break
def evaluate_grpo(
    model,
    ref_model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length,
    reward_funcs=None,
    loss: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
):
    """Evaluate the GRPO loss and metrics over (part of) a dataset.

    Args:
        model: Model being evaluated.
        ref_model: Reference model forwarded to ``loss``.
        dataset: Dataset passed to ``iterate_batches``.
        tokenizer: Tokenizer forwarded to the iterator and the loss.
        batch_size: Batch size for ``iterate_batches``.
        num_batches: Number of batches to evaluate, or -1 for the whole iterator.
        beta: KL penalty coefficient forwarded to ``loss``.
        epsilon: Numerical-stability constant forwarded to ``loss``.
        group_size: Completions per prompt forwarded to ``loss``.
        max_seq_length: Forwarded to ``iterate_batches``.
        reward_funcs: Reward functions forwarded to ``loss``.
        loss: Callable returning ``(loss, ntokens, metrics_dict)``.
        iterate_batches: Callable yielding batches of prompts.

    Returns:
        tuple: ``(avg_loss, ntokens, avg_metrics)``. The loss and every metric
        are token-weighted averages over all batches and all workers.
    """
    all_losses = mx.array(0.0)
    ntokens = mx.array(0)
    all_metrics = None

    # iter(int, 1) never stops: int() always returns 0, never the sentinel 1.
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks, metrics = loss(
            model=model,
            tokenizer=tokenizer,
            prompts=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model,
        )
        all_losses += losses * toks
        ntokens += toks
        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks
        # Evaluate every batch so the lazy graph does not grow across batches.
        mx.eval(all_losses, ntokens, all_metrics)

    # Reduce across workers (on the CPU stream, like the other collectives).
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    all_metrics = {
        k: mx.distributed.all_sum(v, stream=mx.cpu)
        for k, v in (all_metrics or {}).items()
    }
    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()
    return avg_loss, ntokens, avg_metrics
def _r1_count_xml_reward(completions, **kwargs) -> list[float]:
    """Apply ``r1_count_xml`` to each completion.

    ``r1_count_xml`` scores a single string, whereas reward functions receive
    the list of completions and return one score per completion.
    """
    return [r1_count_xml(completion) for completion in completions]


def train_grpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs=None,
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
    """Train ``model`` with GRPO, periodically validating, reporting and saving.

    Args:
        model: Model to train (its trainable parameters are saved as adapters).
        ref_model: Reference model forwarded to ``loss`` for the KL term.
        tokenizer: Tokenizer forwarded to the iterator and the loss.
        optimizer: Optimizer updating ``model``.
        train_dataset: Dataset passed to ``iterate_batches`` for training.
        val_dataset: Dataset passed to ``iterate_batches`` for validation.
        reward_funcs: Reward functions forwarded to ``loss``. Defaults to the
            R1 accuracy, integer, strict-format, soft-format and XML-count
            rewards. NOTE(review): ``r1_accuracy_reward_func`` needs reference
            answers, which ``iterate_grpo_batches`` does not yield -- confirm
            how answers are meant to reach the loss.
        args: Training hyper-parameters.
        loss: Callable returning ``(loss, ntokens, metrics_dict)``.
        iterate_batches: Callable yielding batches of prompts.
        training_callback: Optional receiver of train/val reports.
    """
    if reward_funcs is None:
        # Built per call: a list default would be shared between calls.
        reward_funcs = [
            r1_accuracy_reward_func,
            r1_int_reward_func,
            r1_strict_format_reward_func,
            r1_soft_format_reward_func,
            _r1_count_xml_reward,
        ]

    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]
    loss_value_and_grad = nn.value_and_grad(model, loss)

    def step(prompts):
        # Forward and backward pass. nn.value_and_grad forwards positional
        # and keyword arguments to `loss`.
        (loss_value, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            prompts=prompts,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
        )
        # All reduce the gradients if running in distributed mode
        grad = average_gradients(grad)
        optimizer.update(model, grad)
        return loss_value, toks, metrics

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    # Metric sums since the last report, keyed exactly as returned by `loss`.
    accumulated_metrics = {}

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        # Report validation loss if needed, the first validation loss
        # is always measured before any training.
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss=loss,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                beta=args.beta,
                epsilon=args.epsilon,
                iterate_batches=iterate_batches,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.8f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )
                # Add reward function specific metrics
                for i in range(len(reward_funcs)):
                    val_metrics_str += (
                        f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
                        f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
                    )
                print(
                    f"Iter {it}: {val_metrics_str}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )
            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })
            start = time.perf_counter()

        # Use a distinct name: rebinding `loss` would replace the loss
        # function passed to evaluate_grpo on the next validation.
        loss_value, toks, metrics = step(batch)
        losses += loss_value
        n_tokens += toks
        steps += 1
        for k, v in metrics.items():
            accumulated_metrics[k] = accumulated_metrics.get(k, 0) + v
        mx.eval(state, losses, n_tokens, accumulated_metrics)

        # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()
            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * world_size
            # Sum over workers before dividing by the world size.
            avg_metrics = {
                k: mx.distributed.all_sum(v, stream=mx.cpu).item() / (steps * world_size)
                for k, v in accumulated_metrics.items()
            }
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9
            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.8f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )
                # Add reward function specific metrics
                for i in range(len(reward_funcs)):
                    train_metrics_str += (
                        f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
                        f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
                    )
                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )
            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })
            # Reset every per-report accumulator, metrics included.
            losses = 0
            n_tokens = 0
            steps = 0
            accumulated_metrics = {}
            start = time.perf_counter()

        # Save adapter weights
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")