Goekdeniz-Guelmez 2025-03-05 12:59:41 +01:00
parent c817743333
commit 3dfb21267b
2 changed files with 432 additions and 297 deletions

View File

@@ -1,6 +1,5 @@
-from typing import List, Optional, Callable
 import re
+from typing import Callable, List, Optional

 RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
@@ -14,19 +13,30 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""

-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_int_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]

-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_accuracy_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+    return [
+        2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)
+    ]

-def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_soft_format_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
@@ -41,9 +51,13 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
         answer_start = completion.find("<answer>")
         answer_end = completion.find("</answer>")
-        if (reason_start != -1 and reason_end != -1 and
-            answer_start != -1 and answer_end != -1 and
-            reason_start < reason_end < answer_start < answer_end):
+        if (
+            reason_start != -1
+            and reason_end != -1
+            and answer_start != -1
+            and answer_end != -1
+            and reason_start < reason_end < answer_start < answer_end
+        ):
             reason_content = completion[reason_start + 13 : reason_end].strip()
             answer_content = completion[answer_start + 8 : answer_end].strip()
             if reason_content and answer_content:
@@ -52,14 +66,20 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
                 scores.append(0.0)
     return scores

-def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_strict_format_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"<think> .*? </think><answer> .*? </answer>"
     matches = [bool(re.search(pattern, r)) if r else False for r in completions]
     return [0.5 if match else 0.0 for match in matches]

-def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+def r1_count_xml(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     scores = []
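Note: every function in this file follows the RewardFunctions signature declared at the top: parallel lists of prompts, completions, and answers in, one float per completion out. A minimal usage sketch (the sample strings below are illustrative only, not taken from the repository or its datasets):

    sample_prompts = ["What is 2 + 2?"]
    sample_completions = ["<think> add the two numbers </think><answer> 4 </answer>"]
    sample_answers = ["4"]

    # r1_strict_format_reward_func scores 0.5 when the completion matches the
    # "<think> ... </think><answer> ... </answer>" pattern, otherwise 0.0.
    print(r1_strict_format_reward_func(
        prompts=sample_prompts,
        completions=sample_completions,
        answer=sample_answers,
    ))  # [0.5]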

View File

@@ -1,19 +1,28 @@
 # Copyright © 2024 Apple Inc.

-from typing import List, Optional, Tuple, Generator, Callable, Any
+import time
 from dataclasses import dataclass, field
 from pathlib import Path
-import time
+from typing import Any, Callable, Generator, List, Optional, Tuple
-from mlx.utils import tree_flatten
+
 import mlx.core as mx
 import mlx.nn as nn
 import numpy as np
+from mlx.utils import tree_flatten
+
-from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml,r1_extract_xml_answer, RewardFunctions
-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
-from ..utils import generate_step
 from ..models import cache
+from ..utils import generation_stream
+from .grpo_reward_functions import (
+    RewardFunctions,
+    r1_accuracy_reward_func,
+    r1_count_xml,
+    r1_extract_xml_answer,
+    r1_int_reward_func,
+    r1_soft_format_reward_func,
+    r1_strict_format_reward_func,
+)
+from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint


 @dataclass
 class GRPOTrainingArgs(TrainingArgs):
@@ -21,9 +30,7 @@ class GRPOTrainingArgs(TrainingArgs):
         default=4,
         metadata={"help": "Number of responses per prompt."},
     )
-    beta: float = field(
-        default=0.1, metadata={"help": "KL penalty coefficient."}
-    )
+    beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."})
     epsilon: float = field(
         default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
     )
@@ -34,40 +41,142 @@ class GRPOTrainingArgs(TrainingArgs):
         default=None,
         metadata={
             "help": "Path to reference model weights. If None, uses the same model."
-        }
+        },
     )
     temperature: float = field(
         default=1.0,
         metadata={
             "help": "Temperature for sampling. The higher the temperature, the more random the completions."
-        }
+        },
     )
     reward_weights: Optional[List[float]] = field(
         default=None,
         metadata={
             "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
-        }
+        },
     )
+
+
+def generate_step(
+    prompt: mx.array,
+    model: nn.Module,
+    *,
+    max_tokens: int = 256,
+    sampler: Optional[Callable[mx.array, mx.array]] = None,
+    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
+    max_kv_size: Optional[int] = None,
+    prompt_cache: Optional[Any] = None,
+    prefill_step_size: int = 512,
+    prompt_progress_callback: Optional[Callable[int, int]] = None,
+) -> Generator[Tuple[mx.array, mx.array], None, None]:
+    """
+    A generator producing token ids based on the given prompt from the model.
+
+    Args:
+        prompt (mx.array): The input prompt.
+        model (nn.Module): The model to use for generation.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+            generator. Default: ``256``.
+        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
+            token from a vector of log probabilities. Default: ``None``.
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
+            A list of functions that take tokens and logits and return the processed
+            logits. Default: ``None``.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+            entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+            provided, the cache will be updated in place.
+        prefill_step_size (int): Step size for processing the prompt.
+        kv_bits (int, optional): Number of bits to use for KV cache quantization.
+            None implies no cache quantization. Default: ``None``.
+        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
+        quantized_kv_start (int): Step to begin using a quantized KV cache
+            when ``kv_bits`` is non-None. Default: ``0``.
+        prompt_progress_callback (Callable[int, int]): A callback which takes the
+            prompt tokens processed so far and the total number of prompt tokens.
+
+    Yields:
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    """
+    y = prompt
+    tokens = None
+
+    # Create the KV cache for generation
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+        )
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    def _step(y):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=prompt_cache)
+            logits = logits[:, -1, :]
+
+            if logits_processors:
+                nonlocal tokens
+                tokens = mx.concat([tokens, y]) if tokens is not None else y
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            y = sampler(logprobs)
+            return y, logprobs.squeeze(0)
+
+    with mx.stream(generation_stream):
+        total_prompt_tokens = y.size
+        prompt_processed_tokens = 0
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt_processed_tokens += prefill_step_size
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
+
+        y, logprobs = _step(y)
+
+    mx.eval(y, logprobs)
+    n = 0
+    while True:
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y)
+            mx.eval(next_y, next_logprobs)
+        if n == 0:
+            mx.eval(y)
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+        if n == max_tokens:
+            break
+        yield y.item(), logprobs
+        if n % 256 == 0:
+            mx.metal.clear_cache()
+        y, logprobs = next_y, next_logprobs
+        n += 1
+
+
 def generate_grpo(
     model: nn.Module,
     prompts,
     max_tokens,
     tokenizer,
     group_size,
-    is_training=False,
     end_token: str = "</answer>",
     temperature: float = 0.8,
-    batch_size: int = 1
+    batch_size: int = 1,
 ):
-    # Store original training state
-    was_training = model.training
-    # Set model to eval mode for generation
-    model.eval()
     try:
+        import time
+
+        start_time = time.time()
         if len(prompts.shape) == 1:
             prompts = prompts[None, :]
         if prompts.shape[1] == 0:
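Note: the generate_step generator added above prefills the prompt cache in chunks of prefill_step_size and then yields one (token, logprobs) pair per decoding step. A rough sketch of driving it directly, assuming `model` and `tokenizer` are an already-loaded MLX model/tokenizer pair (those names and the prompt string are placeholders, not part of this diff):

    prompt_ids = mx.array(tokenizer.encode("Briefly: what is 3 + 4?"))

    generated = []
    for token, logprobs in generate_step(
        prompt_ids,
        model,
        max_tokens=64,
        sampler=lambda logits: mx.random.categorical(logits / 0.8),  # temperature 0.8
    ):
        generated.append(token)
        if token == tokenizer.eos_token_id:
            break

    print(tokenizer.decode(generated))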
@@ -79,113 +188,84 @@ def generate_grpo(
         results = []
         mx.eval(expanded_prompts)

+        print(f"Setup time: {time.time() - start_time:.2f}s")
+        print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
+        total_tokens_generated = 0
+        generation_start_time = time.time()
+
         # Process in batches
         for batch_start in range(0, total_samples, batch_size):
             batch_end = min(batch_start + batch_size, total_samples)
-            if is_training:
-                # Training-specific generation logic
-                batch_inputs = expanded_prompts[batch_start:batch_end]
-                batch_tokens = [[] for _ in range(batch_end - batch_start)]
-                prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
-                # Initial forward pass
-                for i, prompt in enumerate(batch_inputs):
-                    logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
-                    logits_temp = logits / temperature
-                    next_token = mx.random.categorical(logits_temp)
-                    token = next_token.item()
-                    batch_tokens[i].append(token)
-                    del logits, logits_temp, next_token
-                mx.eval([tokens[-1] for tokens in batch_tokens])
-                mx.metal.clear_cache()
-                active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id]
-                # Generate remaining tokens
-                for _ in range(max_tokens - 1):
-                    if not active_indices:
-                        break
-                    next_active = []
-                    for idx in active_indices:
-                        current_input = mx.array([batch_tokens[idx][-1]])
-                        logits = model(current_input[None], cache=prompt_caches[idx])[:, -1]
-                        logits_temp = logits / temperature
-                        next_token = mx.random.categorical(logits_temp)
-                        token = next_token.item()
-                        batch_tokens[idx].append(token)
-                        # Check for end conditions
-                        is_end = False
-                        if len(batch_tokens[idx]) >= len(end_sequence):
-                            test_sequence = batch_tokens[idx][-len(end_sequence):]
-                            is_end = mx.array_equal(mx.array(test_sequence), end_sequence)
-                        if not (is_end or token == tokenizer.eos_token_id):
-                            next_active.append(idx)
-                        del logits, logits_temp, next_token, current_input
-                    mx.eval([tokens[-1] for tokens in batch_tokens])
-                    mx.metal.clear_cache()
-                    active_indices = next_active
-                # Clean up caches
-                for pc in prompt_caches:
-                    del pc
-                # Process results
-                for tokens in batch_tokens:
-                    if tokens:
-                        # Truncate at end token if present
-                        for i in range(len(tokens) - len(end_sequence) + 1):
-                            if mx.array_equal(
-                                mx.array(tokens[i:i+len(end_sequence)]),
-                                end_sequence
-                            ):
-                                tokens = tokens[:i+len(end_sequence)]
-                                break
-                        if tokens and tokens[-1] == tokenizer.eos_token_id:
-                            tokens = tokens[:-1]
-                        if tokens:
-                            results.append(mx.array(tokens))
-                del batch_inputs, batch_tokens, prompt_caches
-                mx.metal.clear_cache()
-            else:
-                # Non-training mode with batched processing
-                for idx in range(batch_start, batch_end):
-                    current_tokens = []
-                    generator = generate_step(
-                        expanded_prompts[idx],
-                        model,
-                        max_tokens=max_tokens,
-                        sampler=lambda x: mx.random.categorical(x / temperature)
-                    )
-                    for token, _ in generator:
-                        test_sequence = current_tokens + [token]
-                        if (len(test_sequence) >= len(end_sequence) and
-                            mx.array_equal(
-                                mx.array(test_sequence[-len(end_sequence):]),
-                                end_sequence
-                            )):
-                            current_tokens.append(token)
-                            break
-                        if token == tokenizer.eos_token_id:
-                            break
-                        current_tokens.append(token)
-                    if current_tokens:
-                        results.append(mx.array(current_tokens))
+            batch_time = time.time()
+            print(
+                f"Starting batch {batch_start//batch_size + 1}/{(total_samples + batch_size - 1)//batch_size}: samples {batch_start}-{batch_end-1}"
+            )
+
+            # Custom sampler function that handles temperature
+            def temp_sampler(logits):
+                return mx.random.categorical(logits / temperature)
+
+            # Batched processing
+            for idx in range(batch_start, batch_end):
+                sample_start_time = time.time()
+                current_tokens = []
+                prompt_cache = cache.make_prompt_cache(model)
+
+                # The generate_step function yields one token at a time
+                # We'll collect tokens until we hit max_tokens or a stopping condition
+                for i, (token, _) in enumerate(
+                    generate_step(
+                        expanded_prompts[idx],
+                        model,
+                        max_tokens=max_tokens,  # This is the maximum number of steps
+                        sampler=temp_sampler,
+                        prompt_cache=prompt_cache,
+                    )
+                ):
+                    print(token)
+                    current_tokens.append(token)
+
+                    # Check for end token
+                    if len(current_tokens) >= len(end_sequence) and mx.array_equal(
+                        mx.array(current_tokens[-len(end_sequence) :]), end_sequence
+                    ):
+                        break
+
+                    # Check for EOS token
+                    if token == tokenizer.eos_token_id:
+                        break
+
+                    # Check if we've reached the maximum number of tokens
+                    if i >= max_tokens - 1:
+                        break
+
+                if current_tokens:
+                    results.append(mx.array(current_tokens))
+                    total_tokens_generated += len(current_tokens)
+
+                sample_time = time.time() - sample_start_time
+                tokens_per_second = (
+                    len(current_tokens) / sample_time if sample_time > 0 else 0
+                )
+                print(
+                    f" Sample {idx}: Generated {len(current_tokens)} tokens in {sample_time:.2f}s ({tokens_per_second:.2f} tokens/sec)"
+                )
+
+            batch_time = time.time() - batch_time
+            print(f"Batch completed in {batch_time:.2f}s")
             mx.metal.clear_cache()

+        generation_time = time.time() - generation_start_time
+        avg_tokens_per_second = (
+            total_tokens_generated / generation_time if generation_time > 0 else 0
+        )
+        print(
+            f"Generation complete: {total_tokens_generated} tokens in {generation_time:.2f}s"
+        )
+        print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
+
         mx.eval(results)
         return results
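Note: the loop above stops on either the tokenized end marker or EOS. The end-marker test simply compares the tail of the generated ids with the encoded end_token; a small self-contained illustration with plain lists (the token ids are made up, not real vocabulary ids):

    def ends_with(tokens, end_sequence):
        # True once the last len(end_sequence) generated ids equal the end marker.
        return (
            len(tokens) >= len(end_sequence)
            and tokens[-len(end_sequence):] == end_sequence
        )

    end_sequence = [27, 9399, 29]  # hypothetical ids for "</answer>"
    print(ends_with([1, 2, 3], end_sequence))             # False
    print(ends_with([1, 2, 27, 9399, 29], end_sequence))  # True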
@@ -193,10 +273,6 @@ def generate_grpo(
         print(f"Generation error: {str(e)}")
         return None

-    finally:
-        # Don't restore training mode - let the caller handle it
-        pass


 def get_per_token_logps(model: nn.Module, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
@@ -209,8 +285,7 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
         seq_targets = targets[i, :seq_len]
         log_probs = nn.log_softmax(seq_logits, axis=-1)
         token_log_probs = mx.take_along_axis(
-            log_probs,
-            seq_targets.reshape(seq_len, 1), axis=-1
+            log_probs, seq_targets.reshape(seq_len, 1), axis=-1
         ).squeeze(-1)
         per_token_logps.append(token_log_probs)
     mx.eval(logits)
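Note: get_per_token_logps reduces to a log-softmax over the vocabulary followed by picking out the log-probability of the token actually observed at each position. A tiny standalone sketch with made-up numbers:

    import mlx.core as mx
    import mlx.nn as nn

    logits = mx.array([[2.0, 0.5, -1.0], [0.1, 0.2, 3.0]])  # (seq_len=2, vocab=3), made-up values
    targets = mx.array([0, 2])  # the token id observed at each position

    log_probs = nn.log_softmax(logits, axis=-1)
    token_log_probs = mx.take_along_axis(
        log_probs, targets.reshape(2, 1), axis=-1
    ).squeeze(-1)
    print(token_log_probs)  # one log-probability per sequence position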
@@ -230,7 +305,7 @@ def grpo_loss(
     temperature: float = 0.8,
     reward_weights: Optional[List[float]] = None,
     is_validation: bool = False,
-    batch_size: int = 1
+    batch_size: int = 1,
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
     total_samples = len(prompt_tokens)
@@ -239,6 +314,12 @@
     all_completion_texts = []
     batch_indices = []  # Keep track of which batch each completion belongs to

+    # Store original training state
+    was_training = model.training
+
+    # Set model to eval mode for generation
+    model.eval()
+
     # Process in smaller batches
     for i in range(0, total_samples, batch_size):
         # Get actual batch size for this iteration (might be smaller for the last batch)
@@ -257,7 +338,6 @@
         prompt_tensor = mx.array(padded_prompts)

         try:
-            if is_validation:
             completions = generate_grpo(
                 model,
                 prompt_tensor,
@@ -265,26 +345,16 @@
                 tokenizer,
                 group_size,
                 temperature=temperature,
-                    batch_size=current_batch_size
-                )
-                model.train()
-            else:
-                completions = generate_grpo(
-                    model,
-                    prompt_tensor,
-                    max_tokens,
-                    tokenizer,
-                    group_size,
-                    is_training=True,
-                    temperature=temperature,
-                    batch_size=current_batch_size
+                batch_size=current_batch_size,
             )

             if completions is not None:
                 for j, completion_ids in enumerate(completions):
                     # Calculate which prompt this completion belongs to
                     prompt_idx = i + (j // group_size)
-                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
+                    if (
+                        prompt_idx < total_samples
+                    ):  # Make sure we don't go out of bounds
                         batch_indices.append(prompt_idx)
                         completion_text = tokenizer.decode(completion_ids.tolist())
                         all_completions.append(completion_ids)
@@ -294,12 +364,19 @@
             print(f"Generation error: {e}")
             continue

+    # Restore original training state if we're not in validation mode
+    if not is_validation and was_training:
+        model.train()
+
     mx.metal.clear_cache()

     # If we didn't generate any completions, return early
     if not all_completions:
-        raise ValueError("No completions were generated. Please check your model and inputs.")
+        raise ValueError(
+            "No completions were generated. Please check your model and inputs."
+        )

+    # The rest of the function remains the same
     # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []
@@ -341,7 +418,9 @@
         if padding_length > 0:
             padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
             padded_ids = mx.concatenate([completion_ids, padding])
-            mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
+            mask = mx.concatenate(
+                [mx.ones_like(completion_ids), mx.zeros_like(padding)]
+            )
         else:
             padded_ids = completion_ids
             mask = mx.ones_like(completion_ids)
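Note: the padding step above right-pads every completion to a common length and records which positions hold real tokens. A compact illustration with made-up lengths and ids:

    import mlx.core as mx

    completion_ids = mx.array([5, 6, 7])  # made-up token ids
    max_length = 5
    padding_length = max_length - completion_ids.shape[0]

    padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
    padded_ids = mx.concatenate([completion_ids, padding])
    mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
    print(padded_ids)  # [5, 6, 7, 0, 0]
    print(mask)        # [1, 1, 1, 0, 0]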
@@ -381,11 +460,13 @@
     # Collect rewards from each function separately
     for reward_func in reward_funcs:
-        func_rewards = mx.array(reward_func(
-            prompts=expanded_prompts,
-            completions=all_completion_texts,
-            answer=expanded_answers
-        ))
+        func_rewards = mx.array(
+            reward_func(
+                prompts=expanded_prompts,
+                completions=all_completion_texts,
+                answer=expanded_answers,
+            )
+        )
         all_func_rewards.append(func_rewards)

     # Stack rewards to shape (num_samples, num_funcs)
@@ -422,25 +503,39 @@
             std_reward = mx.std(prompt_rewards)

             # Find indices for this prompt
-            indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]]
+            indices = [
+                j
+                for j, idx in enumerate(batch_indices)
+                if idx == unique_prompt_indices[i]
+            ]
             for j, idx in enumerate(indices):
-                advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon)
+                advantages[idx] = (prompt_rewards[j] - mean_reward) / (
+                    std_reward + epsilon
+                )
         else:
             # If only one sample, advantage is 0
             idx = batch_indices.index(unique_prompt_indices[i])
             advantages[idx] = 0.0

     # Compute KL divergence using Schulman's approximator
-    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
+    kl_div = (
+        mx.exp(ref_token_log_probs - token_log_probs)
+        - (ref_token_log_probs - token_log_probs)
+        - 1
+    )

     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

     # Compute policy ratio
-    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
+    policy_ratio = mx.exp(
+        mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
+    )

     # Compute per-token loss
-    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
+    per_token_loss = -(
+        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    )

     # Average over tokens
     sequence_sums = per_token_loss.sum(axis=1)
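Note: the KL term above is the k3 estimator from John Schulman's approximating-KL note: with per-token log-ratio l = log pi_ref - log pi_theta, the estimate exp(l) - l - 1 is non-negative and zero exactly where the two policies agree. A short numeric check of that identity, plus the per-prompt (group-relative) advantage normalisation used above; all values are made up:

    import mlx.core as mx

    # k3 KL estimator: exp(l) - l - 1 with l = ref_logprob - cur_logprob
    ref_token_log_probs = mx.array([-1.2, -0.7, -2.0])
    token_log_probs = mx.array([-1.0, -0.9, -2.0])
    l = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(l) - l - 1
    print(kl_div)  # >= 0 everywhere, exactly 0 where the log-probs match

    # Group-relative advantages: standardise rewards within one prompt's group
    epsilon = 1e-4
    prompt_rewards = mx.array([2.5, 0.5, 1.0, 1.0])
    advantages = (prompt_rewards - mx.mean(prompt_rewards)) / (mx.std(prompt_rewards) + epsilon)
    print(advantages)  # zero-mean scores within the group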
@@ -454,25 +549,33 @@
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
         func_name = reward_func.__name__
-        func_rewards = mx.array(reward_func(
-            prompts=expanded_prompts,
-            completions=all_completion_texts,
-            answer=expanded_answers
-        ))
-        reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
-        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
+        func_rewards = mx.array(
+            reward_func(
+                prompts=expanded_prompts,
+                completions=all_completion_texts,
+                answer=expanded_answers,
+            )
+        )
+        reward_metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
+        reward_metrics[f"{func_name}_std"] = mx.std(func_rewards)

-    grouped_rewards_mean = mx.array([mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt])
-    grouped_rewards_std = mx.array([mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) for rewards in rewards_by_prompt])
+    grouped_rewards_mean = mx.array(
+        [mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt]
+    )
+    grouped_rewards_std = mx.array(
+        [
+            mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1)
+            for rewards in rewards_by_prompt
+        ]
+    )

     metrics = {
-        'total_rewards_mean': mx.mean(rewards),
-        'total_rewards_std': mx.std(rewards),
-        'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
-        'grouped_rewards_std': mx.mean(grouped_rewards_std),
-        'kl': mean_kl,
-        **reward_metrics
+        "total_rewards_mean": mx.mean(rewards),
+        "total_rewards_std": mx.std(rewards),
+        "grouped_rewards_mean": mx.mean(grouped_rewards_mean),
+        "grouped_rewards_std": mx.mean(grouped_rewards_std),
+        "kl": mean_kl,
+        **reward_metrics,
     }

     if is_validation and all_completion_texts:
@@ -500,8 +603,10 @@
         print("\n" + "=" * 10 + "\n")

         # Only try to extract if r1_extract_xml_answer is defined
-        if 'r1_extract_xml_answer' in globals():
-            print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+        if "r1_extract_xml_answer" in globals():
+            print(
+                f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}"
+            )
         print("\n" + "=" * 35 + "\n")

     mx.metal.clear_cache()
@@ -511,7 +616,9 @@

 def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
     if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
-        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+        raise ValueError(
+            "Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
+        )

     def length_key(i):
         return len(dataset[i][0]) + len(dataset[i][1])
@@ -534,7 +641,8 @@
     while True:
         indices = (
-            np.random.permutation(list(batch_index_generator())) if train
+            np.random.permutation(list(batch_index_generator()))
+            if train
             else batch_index_generator()
         )
@@ -576,10 +684,10 @@ def evaluate_grpo(
         r1_int_reward_func,
         r1_strict_format_reward_func,
         r1_soft_format_reward_func,
-        r1_count_xml
+        r1_count_xml,
     ],
     loss_fn: callable = grpo_loss,
-    iterate_batches: callable = iterate_grpo_batches
+    iterate_batches: callable = iterate_grpo_batches,
 ):
     all_losses = 0
     ntokens = 0
@@ -606,7 +714,7 @@ def evaluate_grpo(
             ref_model=ref_model,
             temperature=temperature,
             max_tokens=max_tokens,
-            is_validation=True
+            is_validation=True,
         )

         all_losses += losses * toks
@@ -642,14 +750,16 @@ def train_grpo(
         r1_int_reward_func,
         r1_strict_format_reward_func,
         r1_soft_format_reward_func,
-        r1_count_xml
+        r1_count_xml,
     ],
     args: GRPOTrainingArgs = GRPOTrainingArgs(),
     loss_fn: callable = grpo_loss,
     iterate_batches: callable = iterate_grpo_batches,
     training_callback: TrainingCallback = None,
 ):
-    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
+    print(
+        f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}"
+    )
     world = mx.distributed.init()
     world_size = world.size()
     rank = world.rank()
@@ -672,7 +782,7 @@ def train_grpo(
         epsilon=args.epsilon,
         ref_model=ref_model,
         max_tokens=args.max_completion_length,
-        temperature=args.temperature
+        temperature=args.temperature,
     )

     grad = average_gradients(grad)
@@ -688,16 +798,16 @@ def train_grpo(
     steps = 0
     trained_tokens = 0
     accumulated_metrics = {
-        'total_rewards_mean': 0,
-        'total_rewards_std': 0,
-        'grouped_rewards_mean': 0,
-        'grouped_rewards_std': 0,
-        'kl': 0
+        "total_rewards_mean": 0,
+        "total_rewards_std": 0,
+        "grouped_rewards_mean": 0,
+        "grouped_rewards_std": 0,
+        "kl": 0,
     }
     for reward_func in reward_funcs:
         func_name = reward_func.__name__
-        accumulated_metrics[f'{func_name}_mean'] = 0
-        accumulated_metrics[f'{func_name}_std'] = 0
+        accumulated_metrics[f"{func_name}_mean"] = 0
+        accumulated_metrics[f"{func_name}_std"] = 0

     start = time.perf_counter()
     for it, batch in zip(
@@ -746,18 +856,19 @@ def train_grpo(
             )
             print(
-                f"Iter {it}: {val_metrics_str}, "
-                f"Val took {val_time:.3f}s",
+                f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s",
                 flush=True,
             )

             if training_callback is not None:
-                training_callback.on_val_loss_report({
-                    "iteration": it,
-                    "val_loss": val_loss,
-                    **{f"val_{k}": v for k, v in val_metrics.items()},
-                    "val_time": val_time,
-                })
+                training_callback.on_val_loss_report(
+                    {
+                        "iteration": it,
+                        "val_loss": val_loss,
+                        **{f"val_{k}": v for k, v in val_metrics.items()},
+                        "val_time": val_time,
+                    }
+                )

             start = time.perf_counter()
@@ -776,7 +887,9 @@ def train_grpo(
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
-            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
+            avg_metrics = {
+                k: v / (steps * world_size) for k, v in accumulated_metrics.items()
+            }
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
@@ -811,7 +924,8 @@ def train_grpo(
             )

             if training_callback is not None:
-                training_callback.on_train_loss_report({
+                training_callback.on_train_loss_report(
+                    {
                         "iteration": it,
                         "train_loss": train_loss,
                         **{f"train_{k}": v for k, v in avg_metrics.items()},
@@ -820,7 +934,8 @@ def train_grpo(
                         "tokens_per_second": tokens_sec,
                         "trained_tokens": trained_tokens,
                         "peak_memory": peak_mem,
-                })
+                    }
+                )

             losses = 0
             n_tokens = 0