Goekdeniz-Guelmez 2025-03-05 12:59:41 +01:00
parent c817743333
commit 3dfb21267b
2 changed files with 432 additions and 297 deletions


@@ -1,6 +1,5 @@
from typing import List, Optional, Callable
import re
from typing import Callable, List, Optional
RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
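# A minimal sketch (not part of this commit) of a custom reward function that
# satisfies the RewardFunctions alias above: it receives parallel lists of
# prompts, completions, and reference answers and returns one float per
# completion. The name and scoring rule are hypothetical.
def example_length_reward(
    prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
    # Favor shorter completions, capped at 1.0; empty completions score 0.0.
    return [min(1.0, 100.0 / len(c)) if c else 0.0 for c in completions]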
@@ -14,19 +13,30 @@ def r1_extract_xml_answer(text: str) -> str:
print("r1_extract_xml_answer returned empty string")
return ""
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_int_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_accuracy_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions or not answer:
return [0.0] * len(prompts)
extracted_responses = [r1_extract_xml_answer(r) for r in completions]
return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
return [
2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)
]
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_soft_format_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions:
return [0.0] * len(prompts)
@@ -41,9 +51,13 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
answer_start = completion.find("<answer>")
answer_end = completion.find("</answer>")
if (reason_start != -1 and reason_end != -1 and
answer_start != -1 and answer_end != -1 and
reason_start < reason_end < answer_start < answer_end):
if (
reason_start != -1
and reason_end != -1
and answer_start != -1
and answer_end != -1
and reason_start < reason_end < answer_start < answer_end
):
reason_content = completion[reason_start + 13 : reason_end].strip()
answer_content = completion[answer_start + 8 : answer_end].strip()
if reason_content and answer_content:
@@ -52,14 +66,20 @@ def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, *
scores.append(0.0)
return scores
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_strict_format_reward_func(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions:
return [0.0] * len(prompts)
pattern = r"<think> .*? </think><answer> .*? </answer>"
matches = [bool(re.search(pattern, r)) if r else False for r in completions]
return [0.5 if match else 0.0 for match in matches]
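# Quick usage sketch (hypothetical strings, not from this commit): the strict
# pattern only matches when each tag wraps its content in single spaces and the
# <answer> block directly follows </think>.
demo = [
    "<think> 2 + 2 = 4 </think><answer> 4 </answer>",  # matches -> 0.5
    "<think>2 + 2 = 4</think> <answer>4</answer>",     # no match -> 0.0
]
print(r1_strict_format_reward_func(prompts=["p", "p"], completions=demo, answer=["4", "4"]))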
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
def r1_count_xml(
prompts: list, completions: list, answer: list, **kwargs
) -> list[float]:
if not completions:
return [0.0] * len(prompts)
scores = []


@@ -1,19 +1,28 @@
# Copyright © 2024 Apple Inc.
from typing import List, Optional, Tuple, Generator, Callable, Any
import time
from dataclasses import dataclass, field
from pathlib import Path
import time
from typing import Any, Callable, Generator, List, Optional, Tuple
from mlx.utils import tree_flatten
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml,r1_extract_xml_answer, RewardFunctions
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
from ..utils import generate_step
from ..models import cache
from ..utils import generation_stream
from .grpo_reward_functions import (
RewardFunctions,
r1_accuracy_reward_func,
r1_count_xml,
r1_extract_xml_answer,
r1_int_reward_func,
r1_soft_format_reward_func,
r1_strict_format_reward_func,
)
from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint
@dataclass
class GRPOTrainingArgs(TrainingArgs):
@@ -21,9 +30,7 @@ class GRPOTrainingArgs(TrainingArgs):
default=4,
metadata={"help": "Number of responses per prompt."},
)
beta: float = field(
default=0.1, metadata={"help": "KL penalty coefficient."}
)
beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."})
epsilon: float = field(
default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
)
@@ -34,40 +41,142 @@ class GRPOTrainingArgs(TrainingArgs):
default=None,
metadata={
"help": "Path to reference model weights. If None, uses the same model."
}
},
)
temperature: float = field(
default=1.0,
metadata={
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
}
},
)
reward_weights: Optional[List[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
}
},
)
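# Illustrative configuration sketch: only fields whose names are visible in this
# hunk are set explicitly (the values are arbitrary, not recommendations); all
# other GRPOTrainingArgs / inherited TrainingArgs fields are assumed to keep
# their defaults.
args = GRPOTrainingArgs(
    beta=0.04,            # KL penalty coefficient
    epsilon=1e-4,         # numerical-stability term used when normalizing advantages
    temperature=0.8,      # sampling temperature for generated completions
    reward_weights=None,  # None -> every reward function gets weight 1.0
)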
def generate_step(
prompt: mx.array,
model: nn.Module,
*,
max_tokens: int = 256,
sampler: Optional[Callable[[mx.array], mx.array]] = None,
logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
prefill_step_size: int = 512,
prompt_progress_callback: Optional[Callable[[int, int], None]] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
"""
A generator producing token ids based on the given prompt from the model.
Args:
prompt (mx.array): The input prompt.
model (nn.Module): The model to use for generation.
max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
generator. Default: ``256``.
sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
token from a vector of log probabilities. Default: ``None``.
logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
A list of functions that take tokens and logits and return the processed
logits. Default: ``None``.
max_kv_size (int, optional): Maximum size of the key-value cache. Old
entries (except the first 4 tokens) will be overwritten.
prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
provided, the cache will be updated in place.
prefill_step_size (int): Step size for processing the prompt.
prompt_progress_callback (Callable[[int, int], None], optional): A callback which takes the
prompt tokens processed so far and the total number of prompt tokens.
Yields:
Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
"""
y = prompt
tokens = None
# Create the KV cache for generation
if prompt_cache is None:
prompt_cache = cache.make_prompt_cache(
model,
max_kv_size=max_kv_size,
)
elif len(prompt_cache) != len(model.layers):
raise ValueError("Wrong number of layers in the prompt cache.")
prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processors:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
for processor in logits_processors:
logits = processor(tokens, logits)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
y = sampler(logprobs)
return y, logprobs.squeeze(0)
with mx.stream(generation_stream):
total_prompt_tokens = y.size
prompt_processed_tokens = 0
while y.size > prefill_step_size:
model(y[:prefill_step_size][None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
prompt_processed_tokens += prefill_step_size
y = y[prefill_step_size:]
mx.metal.clear_cache()
y, logprobs = _step(y)
mx.eval(y, logprobs)
n = 0
while True:
if n != max_tokens:
next_y, next_logprobs = _step(y)
mx.eval(next_y, next_logprobs)
if n == 0:
mx.eval(y)
prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
if n == max_tokens:
break
yield y.item(), logprobs
if n % 256 == 0:
mx.metal.clear_cache()
y, logprobs = next_y, next_logprobs
n += 1
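# Usage sketch for the local generate_step above; `model`, `tokenizer`, and
# `prompt_ids` (an mx.array of prompt token ids) are assumed to already exist.
# Tokens are collected until EOS or until the generator hits max_tokens.
collected = []
for token, _ in generate_step(
    prompt_ids,
    model,
    max_tokens=64,
    sampler=lambda logprobs: mx.random.categorical(logprobs / 0.8),  # temperature 0.8
):
    if token == tokenizer.eos_token_id:
        break
    collected.append(token)
completion_text = tokenizer.decode(collected)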
def generate_grpo(
model: nn.Module,
prompts,
max_tokens,
tokenizer,
group_size,
is_training=False,
end_token: str = "</answer>",
temperature: float = 0.8,
batch_size: int = 1
batch_size: int = 1,
):
# Store original training state
was_training = model.training
# Set model to eval mode for generation
model.eval()
try:
import time
start_time = time.time()
if len(prompts.shape) == 1:
prompts = prompts[None, :]
if prompts.shape[1] == 0:
@@ -79,113 +188,84 @@ def generate_grpo(
results = []
mx.eval(expanded_prompts)
print(f"Setup time: {time.time() - start_time:.2f}s")
print(f"Generating {total_samples} samples with max_tokens={max_tokens}")
total_tokens_generated = 0
generation_start_time = time.time()
# Process in batches
for batch_start in range(0, total_samples, batch_size):
batch_end = min(batch_start + batch_size, total_samples)
if is_training:
# Training-specific generation logic
batch_inputs = expanded_prompts[batch_start:batch_end]
batch_tokens = [[] for _ in range(batch_end - batch_start)]
prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
# Initial forward pass
for i, prompt in enumerate(batch_inputs):
logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
logits_temp = logits / temperature
next_token = mx.random.categorical(logits_temp)
token = next_token.item()
batch_tokens[i].append(token)
del logits, logits_temp, next_token
mx.eval([tokens[-1] for tokens in batch_tokens])
mx.metal.clear_cache()
active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id]
# Generate remaining tokens
for _ in range(max_tokens - 1):
if not active_indices:
break
next_active = []
for idx in active_indices:
current_input = mx.array([batch_tokens[idx][-1]])
logits = model(current_input[None], cache=prompt_caches[idx])[:, -1]
logits_temp = logits / temperature
next_token = mx.random.categorical(logits_temp)
token = next_token.item()
batch_tokens[idx].append(token)
# Check for end conditions
is_end = False
if len(batch_tokens[idx]) >= len(end_sequence):
test_sequence = batch_tokens[idx][-len(end_sequence):]
is_end = mx.array_equal(mx.array(test_sequence), end_sequence)
if not (is_end or token == tokenizer.eos_token_id):
next_active.append(idx)
del logits, logits_temp, next_token, current_input
mx.eval([tokens[-1] for tokens in batch_tokens])
mx.metal.clear_cache()
active_indices = next_active
# Clean up caches
for pc in prompt_caches:
del pc
# Process results
for tokens in batch_tokens:
if tokens:
# Truncate at end token if present
for i in range(len(tokens) - len(end_sequence) + 1):
if mx.array_equal(
mx.array(tokens[i:i+len(end_sequence)]),
end_sequence
):
tokens = tokens[:i+len(end_sequence)]
break
if tokens and tokens[-1] == tokenizer.eos_token_id:
tokens = tokens[:-1]
if tokens:
results.append(mx.array(tokens))
del batch_inputs, batch_tokens, prompt_caches
mx.metal.clear_cache()
else:
# Non-training mode with batched processing
for idx in range(batch_start, batch_end):
current_tokens = []
generator = generate_step(
expanded_prompts[idx],
model,
max_tokens=max_tokens,
sampler=lambda x: mx.random.categorical(x / temperature)
batch_time = time.time()
print(
f"Starting batch {batch_start//batch_size + 1}/{(total_samples + batch_size - 1)//batch_size}: samples {batch_start}-{batch_end-1}"
)
for token, _ in generator:
test_sequence = current_tokens + [token]
if (len(test_sequence) >= len(end_sequence) and
mx.array_equal(
mx.array(test_sequence[-len(end_sequence):]),
end_sequence
)):
# Custom sampler function that handles temperature
def temp_sampler(logits):
return mx.random.categorical(logits / temperature)
# Batched processing
for idx in range(batch_start, batch_end):
sample_start_time = time.time()
current_tokens = []
prompt_cache = cache.make_prompt_cache(model)
# The generate_step function yields one token at a time
# We'll collect tokens until we hit max_tokens or a stopping condition
for i, (token, _) in enumerate(
generate_step(
expanded_prompts[idx],
model,
max_tokens=max_tokens, # This is the maximum number of steps
sampler=temp_sampler,
prompt_cache=prompt_cache,
)
):
print(token)
current_tokens.append(token)
# Check for end token
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
mx.array(current_tokens[-len(end_sequence) :]), end_sequence
):
break
# Check for EOS token
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
# Check if we've reached the maximum number of tokens
if i >= max_tokens - 1:
break
if current_tokens:
results.append(mx.array(current_tokens))
total_tokens_generated += len(current_tokens)
sample_time = time.time() - sample_start_time
tokens_per_second = (
len(current_tokens) / sample_time if sample_time > 0 else 0
)
print(
f" Sample {idx}: Generated {len(current_tokens)} tokens in {sample_time:.2f}s ({tokens_per_second:.2f} tokens/sec)"
)
batch_time = time.time() - batch_time
print(f"Batch completed in {batch_time:.2f}s")
mx.metal.clear_cache()
generation_time = time.time() - generation_start_time
avg_tokens_per_second = (
total_tokens_generated / generation_time if generation_time > 0 else 0
)
print(
f"Generation complete: {total_tokens_generated} tokens in {generation_time:.2f}s"
)
print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
mx.eval(results)
return results
@@ -193,10 +273,6 @@ def generate_grpo(
print(f"Generation error: {str(e)}")
return None
finally:
# Don't restore training mode - let the caller handle it
pass
def get_per_token_logps(model: nn.Module, inputs, lengths):
logits = model(inputs).astype(mx.float16)
@@ -209,8 +285,7 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
seq_targets = targets[i, :seq_len]
log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
log_probs,
seq_targets.reshape(seq_len, 1), axis=-1
log_probs, seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)
per_token_logps.append(token_log_probs)
mx.eval(logits)
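# Standalone sketch of the gather above (toy numbers, assuming only `mx` from the
# imports at the top of the file): take_along_axis picks, at each position, the
# log-probability assigned to the token that actually occurs next.
lp = mx.array([[-0.1, -2.3, -3.0], [-1.2, -0.4, -2.0]])    # (seq_len=2, vocab=3)
tgt = mx.array([0, 1]).reshape(2, 1)                       # realized next-token ids
picked = mx.take_along_axis(lp, tgt, axis=-1).squeeze(-1)  # -> [-0.1, -0.4]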
@@ -230,7 +305,7 @@ def grpo_loss(
temperature: float = 0.8,
reward_weights: Optional[List[float]] = None,
is_validation: bool = False,
batch_size: int = 1
batch_size: int = 1,
):
prompt_tokens, _, prompt_text, answer_text = batch
total_samples = len(prompt_tokens)
@@ -239,6 +314,12 @@ def grpo_loss(
all_completion_texts = []
batch_indices = [] # Keep track of which batch each completion belongs to
# Store original training state
was_training = model.training
# Set model to eval mode for generation
model.eval()
# Process in smaller batches
for i in range(0, total_samples, batch_size):
# Get actual batch size for this iteration (might be smaller for the last batch)
@@ -257,7 +338,6 @@ def grpo_loss(
prompt_tensor = mx.array(padded_prompts)
try:
if is_validation:
completions = generate_grpo(
model,
prompt_tensor,
@@ -265,26 +345,16 @@ def grpo_loss(
tokenizer,
group_size,
temperature=temperature,
batch_size=current_batch_size
)
model.train()
else:
completions = generate_grpo(
model,
prompt_tensor,
max_tokens,
tokenizer,
group_size,
is_training=True,
temperature=temperature,
batch_size=current_batch_size
batch_size=current_batch_size,
)
if completions is not None:
for j, completion_ids in enumerate(completions):
# Calculate which prompt this completion belongs to
prompt_idx = i + (j // group_size)
if prompt_idx < total_samples: # Make sure we don't go out of bounds
if (
prompt_idx < total_samples
): # Make sure we don't go out of bounds
batch_indices.append(prompt_idx)
completion_text = tokenizer.decode(completion_ids.tolist())
all_completions.append(completion_ids)
@@ -294,12 +364,19 @@ def grpo_loss(
print(f"Generation error: {e}")
continue
# Restore original training state if we're not in validation mode
if not is_validation and was_training:
model.train()
mx.metal.clear_cache()
# If we didn't generate any completions, return early
if not all_completions:
raise ValueError("No completions were generated. Please check your model and inputs.")
raise ValueError(
"No completions were generated. Please check your model and inputs."
)
# The rest of the function remains the same
# Create expanded prompts and answers based on actual generated completions
expanded_answers = []
expanded_prompts = []
@@ -341,7 +418,9 @@ def grpo_loss(
if padding_length > 0:
padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
padded_ids = mx.concatenate([completion_ids, padding])
mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)])
mask = mx.concatenate(
[mx.ones_like(completion_ids), mx.zeros_like(padding)]
)
else:
padded_ids = completion_ids
mask = mx.ones_like(completion_ids)
@@ -381,11 +460,13 @@ def grpo_loss(
# Collect rewards from each function separately
for reward_func in reward_funcs:
func_rewards = mx.array(reward_func(
func_rewards = mx.array(
reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers
))
answer=expanded_answers,
)
)
all_func_rewards.append(func_rewards)
# Stack rewards to shape (num_samples, num_funcs)
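# The stacking and weighting code itself falls outside this hunk; a plausible
# sketch (not the commit's actual implementation), assuming `all_func_rewards`
# holds one (num_samples,) array per reward function and `reward_weights` is
# either None or one weight per function:
stacked = mx.stack(all_func_rewards, axis=1)        # (num_samples, num_funcs)
weights = (
    mx.array(reward_weights)
    if reward_weights is not None
    else mx.ones(stacked.shape[1])
)
rewards = (stacked * weights[None, :]).sum(axis=1)  # weighted total per sample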
@@ -422,25 +503,39 @@ def grpo_loss(
std_reward = mx.std(prompt_rewards)
# Find indices for this prompt
indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]]
indices = [
j
for j, idx in enumerate(batch_indices)
if idx == unique_prompt_indices[i]
]
for j, idx in enumerate(indices):
advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon)
advantages[idx] = (prompt_rewards[j] - mean_reward) / (
std_reward + epsilon
)
else:
# If only one sample, advantage is 0
idx = batch_indices.index(unique_prompt_indices[i])
advantages[idx] = 0.0
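# Worked example with made-up numbers: if one prompt's group rewards are
# [1.0, 2.0, 3.0, 2.0], then mean = 2.0 and std ≈ 0.707, giving group-relative
# advantages of roughly [-1.41, 0.0, 1.41, 0.0] (before the epsilon term).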
# Compute KL divergence using Schulman's approximator
kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
kl_div = (
mx.exp(ref_token_log_probs - token_log_probs)
- (ref_token_log_probs - token_log_probs)
- 1
)
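# This is the non-negative "k3" estimator: with r = exp(ref_logp - logp),
# k3 = r - log(r) - 1, which estimates KL(policy || reference) from samples
# drawn by the policy.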
# Create mask for valid tokens
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
# Compute policy ratio
policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
policy_ratio = mx.exp(
mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))
)
# Compute per-token loss
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
per_token_loss = -(
(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
)
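# i.e. loss_t = -((pi_theta / pi_ref)_t * advantage - beta * kl_t) on every
# unmasked token; the reference log-probs inside the ratio sit behind
# stop_gradient, so the policy is updated only through token_log_probs.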
# Average over tokens
sequence_sums = per_token_loss.sum(axis=1)
@@ -454,25 +549,33 @@ def grpo_loss(
reward_metrics = {}
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
func_rewards = mx.array(reward_func(
func_rewards = mx.array(
reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers
))
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
answer=expanded_answers,
)
)
reward_metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
reward_metrics[f"{func_name}_std"] = mx.std(func_rewards)
grouped_rewards_mean = mx.array([mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt])
grouped_rewards_std = mx.array([mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) for rewards in rewards_by_prompt])
grouped_rewards_mean = mx.array(
[mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt]
)
grouped_rewards_std = mx.array(
[
mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1)
for rewards in rewards_by_prompt
]
)
metrics = {
'total_rewards_mean': mx.mean(rewards),
'total_rewards_std': mx.std(rewards),
'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
'grouped_rewards_std': mx.mean(grouped_rewards_std),
'kl': mean_kl,
**reward_metrics
"total_rewards_mean": mx.mean(rewards),
"total_rewards_std": mx.std(rewards),
"grouped_rewards_mean": mx.mean(grouped_rewards_mean),
"grouped_rewards_std": mx.mean(grouped_rewards_std),
"kl": mean_kl,
**reward_metrics,
}
if is_validation and all_completion_texts:
@@ -500,8 +603,10 @@ def grpo_loss(
print("\n" + "=" * 10 + "\n")
# Only try to extract if r1_extract_xml_answer is defined
if 'r1_extract_xml_answer' in globals():
print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
if "r1_extract_xml_answer" in globals():
print(
f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}"
)
print("\n" + "=" * 35 + "\n")
mx.metal.clear_cache()
@@ -511,7 +616,9 @@ def grpo_loss(
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
raise ValueError(
"Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
)
def length_key(i):
return len(dataset[i][0]) + len(dataset[i][1])
@@ -534,7 +641,8 @@ def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
while True:
indices = (
np.random.permutation(list(batch_index_generator())) if train
np.random.permutation(list(batch_index_generator()))
if train
else batch_index_generator()
)
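# Shuffling operates on whole batch-index groups from batch_index_generator
# (presumably built over the length-sorted order implied by length_key above),
# so training randomizes batch order while keeping each batch's composition.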
@@ -576,10 +684,10 @@ def evaluate_grpo(
r1_int_reward_func,
r1_strict_format_reward_func,
r1_soft_format_reward_func,
r1_count_xml
r1_count_xml,
],
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches
iterate_batches: callable = iterate_grpo_batches,
):
all_losses = 0
ntokens = 0
@@ -606,7 +714,7 @@ def evaluate_grpo(
ref_model=ref_model,
temperature=temperature,
max_tokens=max_tokens,
is_validation=True
is_validation=True,
)
all_losses += losses * toks
@@ -642,14 +750,16 @@ def train_grpo(
r1_int_reward_func,
r1_strict_format_reward_func,
r1_soft_format_reward_func,
r1_count_xml
r1_count_xml,
],
args: GRPOTrainingArgs = GRPOTrainingArgs(),
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches,
training_callback: TrainingCallback = None,
):
print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
print(
f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}"
)
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
@@ -672,7 +782,7 @@ def train_grpo(
epsilon=args.epsilon,
ref_model=ref_model,
max_tokens=args.max_completion_length,
temperature=args.temperature
temperature=args.temperature,
)
grad = average_gradients(grad)
@@ -688,16 +798,16 @@ def train_grpo(
steps = 0
trained_tokens = 0
accumulated_metrics = {
'total_rewards_mean': 0,
'total_rewards_std': 0,
'grouped_rewards_mean': 0,
'grouped_rewards_std': 0,
'kl': 0
"total_rewards_mean": 0,
"total_rewards_std": 0,
"grouped_rewards_mean": 0,
"grouped_rewards_std": 0,
"kl": 0,
}
for reward_func in reward_funcs:
func_name = reward_func.__name__
accumulated_metrics[f'{func_name}_mean'] = 0
accumulated_metrics[f'{func_name}_std'] = 0
accumulated_metrics[f"{func_name}_mean"] = 0
accumulated_metrics[f"{func_name}_std"] = 0
start = time.perf_counter()
for it, batch in zip(
@@ -746,18 +856,19 @@ def train_grpo(
)
print(
f"Iter {it}: {val_metrics_str}, "
f"Val took {val_time:.3f}s",
f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
training_callback.on_val_loss_report({
training_callback.on_val_loss_report(
{
"iteration": it,
"val_loss": val_loss,
**{f"val_{k}": v for k, v in val_metrics.items()},
"val_time": val_time,
})
}
)
start = time.perf_counter()
@@ -776,7 +887,9 @@ def train_grpo(
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size()
avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
avg_metrics = {
k: v / (steps * world_size) for k, v in accumulated_metrics.items()
}
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
@@ -811,7 +924,8 @@ def train_grpo(
)
if training_callback is not None:
training_callback.on_train_loss_report({
training_callback.on_train_loss_report(
{
"iteration": it,
"train_loss": train_loss,
**{f"train_{k}": v for k, v in avg_metrics.items()},
@@ -820,7 +934,8 @@ def train_grpo(
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
})
}
)
losses = 0
n_tokens = 0