mlx-examples/llms/mlx_lm/tuner/grpo_trainer.py

823 lines
28 KiB
Python
Raw Normal View History

2025-01-31 06:55:34 +08:00
# Copyright © 2024 Apple Inc.
2025-03-05 19:59:41 +08:00
import time
2025-01-31 06:55:34 +08:00
from dataclasses import dataclass, field
from pathlib import Path
2025-03-05 19:59:41 +08:00
from typing import Any, Callable, Generator, List, Optional, Tuple
2025-01-31 06:55:34 +08:00
import mlx.core as mx
import mlx.nn as nn
import numpy as np
2025-03-05 19:59:41 +08:00
from mlx.utils import tree_flatten
2025-01-31 06:55:34 +08:00
from ..models import cache
2025-03-05 19:59:41 +08:00
from ..utils import generation_stream
from .grpo_reward_functions import (
RewardFunctions,
r1_accuracy_reward_func,
r1_count_xml,
r1_extract_xml_answer,
r1_int_reward_func,
r1_soft_format_reward_func,
r1_strict_format_reward_func,
)
from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint
2025-01-31 06:55:34 +08:00
@dataclass
class GRPOTrainingArgs(TrainingArgs):
    """Training arguments for Group Relative Policy Optimization (GRPO).

    Extends ``TrainingArgs`` with the sampling and loss hyperparameters read by
    ``train_grpo`` and forwarded to the GRPO loss.
    """

    # Completions sampled for every prompt; rewards are normalized within
    # each such group to form advantages.
    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    # Weight of the KL(policy || reference) penalty in the loss.
    beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."})
    # Added to the per-group reward standard deviation when normalizing.
    epsilon: float = field(
        default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
    )
    # Upper bound on generated tokens per completion (was mis-documented as
    # "Number of Generations.").
    max_completion_length: int = field(
        default=512,
        metadata={"help": "Maximum number of tokens to generate per completion."},
    )
    reference_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        },
    )
    temperature: float = field(
        default=1.0,
        metadata={
            "help": "Temperature for sampling. The higher the temperature, the more random the completions."
        },
    )
    reward_weights: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
        },
    )
2025-01-31 06:55:34 +08:00
2025-02-05 18:30:21 +08:00
2025-03-09 07:16:40 +08:00
def get_per_token_logps(model: nn.Module, inputs, lengths):
    """Return the log-probability the model assigns to each next token.

    Args:
        model: Causal language model; ``model(inputs)`` is indexed as
            ``(batch, seq_len, vocab)`` logits.
        inputs: Integer token ids of shape ``(batch, seq_len)``; only the first
            ``lengths[i]`` tokens of row ``i`` are scored.
        lengths: Number of valid tokens in each row of ``inputs``.

    Returns:
        A list with one 1-D float16 array per row. Entry ``t`` of row ``i`` is
        ``log p(inputs[i, t + 1] | inputs[i, : t + 1])`` and the array has
        ``max(lengths[i] - 1, 0)`` entries.
    """
    logits = model(inputs).astype(mx.float16)
    # Logit at position t predicts token t + 1: drop the last logit and the
    # first target so they line up.
    logits = logits[:, :-1, :]
    targets = inputs[:, 1:]
    per_token_logps = []
    for i in range(logits.shape[0]):
        # Clamp at 0: for an empty row ``lengths[i] - 1`` is -1 and the slice
        # ``[:-1]`` would silently score almost the whole padded row.
        seq_len = max(int(lengths[i]) - 1, 0)
        seq_logits = logits[i, :seq_len]
        seq_targets = targets[i, :seq_len]
        log_probs = nn.log_softmax(seq_logits, axis=-1)
        token_log_probs = mx.take_along_axis(
            log_probs, seq_targets.reshape(seq_len, 1), axis=-1
        ).squeeze(-1)
        per_token_logps.append(token_log_probs)
    mx.eval(logits)
    return per_token_logps
2025-03-05 19:59:41 +08:00
def generate_step(
    prompt: mx.array,
    model: nn.Module,
    max_tokens: int = 256,
    sampler: Optional[Callable] = None,
    logits_processors: Optional[List[Callable]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
) -> Generator[Tuple[int, mx.array], None, None]:
    """Autoregressively sample up to ``max_tokens`` tokens after ``prompt``.

    Args:
        prompt: 1-D array of prompt token ids.
        model: Causal LM, called as ``model(tokens[None], cache=prompt_cache)``.
        max_tokens: Maximum number of tokens to yield.
        sampler: Maps the ``(1, vocab)`` log-probabilities to the next token.
            Defaults to greedy ``argmax`` decoding.
        logits_processors: Optional callables ``processor(tokens, logits)``
            applied in order; ``tokens`` holds every token fed so far.
        max_kv_size: Forwarded to ``cache.make_prompt_cache`` when no
            ``prompt_cache`` is supplied.
        prompt_cache: KV cache to reuse; a fresh one is created if ``None``.

    Yields:
        ``(token_id, logprobs)`` pairs, where ``token_id`` is a Python int and
        ``logprobs`` is the 1-D log-probability vector it was sampled from.
    """
    if sampler is None:
        # Previously a missing sampler crashed on first use; fall back to greedy.
        def sampler(logprobs):
            return mx.argmax(logprobs, axis=-1)

    tokens = None
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(model, max_kv_size=max_kv_size)

    def _step(y):
        nonlocal tokens
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)[:, -1, :]
            if logits_processors:
                tokens = mx.concat([tokens, y]) if tokens is not None else y
                for processor in logits_processors:
                    logits = processor(tokens, logits)
            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
            next_token = sampler(logprobs)
        return mx.stop_gradient(next_token), mx.stop_gradient(logprobs.squeeze(0))

    try:
        y, logprobs = _step(prompt)
        mx.eval(y, logprobs)
        for n in range(max_tokens):
            yield y.item(), logprobs
            if n + 1 == max_tokens:
                # Nothing more will be yielded: skip a wasted forward pass.
                break
            y, logprobs = _step(y)
            mx.eval(y, logprobs)
            if (n + 1) % 32 == 0:
                mx.metal.clear_cache()
    finally:
        mx.metal.clear_cache()
2025-03-05 19:59:41 +08:00
def _sample_completion(model, tokenizer, prompt, max_tokens, sampler, end_ids):
    """Sample one completion for ``prompt`` and return its token ids as a list.

    Generation stops at the tokenizer's EOS token (not included) or right
    after the token sequence ``end_ids`` has been produced (included).
    """
    tokens = []
    prompt_cache = cache.make_prompt_cache(model)
    for token, _ in generate_step(
        prompt,
        model,
        max_tokens=max_tokens,
        sampler=sampler,
        prompt_cache=prompt_cache,
    ):
        if token == tokenizer.eos_token_id:
            break
        tokens.append(token)
        if end_ids and tokens[-len(end_ids) :] == end_ids:
            break
    return tokens


def generate_grpo(
    model: nn.Module,
    tokenizer,
    prompt_tokens,
    max_tokens: int,
    group_size: int,
    end_token: str = "</answer>",
    temperature: float = 0.8,
    batch_size: int = 1,
):
    """Sample ``group_size`` completions for every prompt in ``prompt_tokens``.

    Each prompt is generated from its own, unpadded token ids. Earlier
    versions right-padded prompts with ``pad_token_id``, so the model saw pad
    tokens before generating.

    Args:
        model: Policy model to sample from.
        tokenizer: Provides ``encode``, ``decode`` and ``eos_token_id``.
        prompt_tokens: Sequence of prompt token-id lists; empty prompts are
            skipped.
        max_tokens: Maximum tokens per completion.
        group_size: Completions sampled per prompt.
        end_token: Text whose token ids end a completion once generated.
        temperature: Sampling temperature applied to the log-probabilities.
        batch_size: Number of prompts processed between cache clears.

    Returns:
        ``(completions, completion_texts, batch_indices)``:
        ``completions`` are 1-D token arrays, ``completion_texts`` their
        decoded strings, and ``batch_indices[k]`` is the position in
        ``prompt_tokens`` of the prompt that produced completion ``k``.
        Empty completions are dropped, and the index bookkeeping stays aligned
        when they are.
    """
    all_completions = []
    all_completion_texts = []
    batch_indices = []

    def temp_sampler(logits):
        return mx.random.categorical(logits / temperature)

    try:
        end_ids = list(tokenizer.encode(end_token))
        for batch_start in range(0, len(prompt_tokens), batch_size):
            batch_prompts = prompt_tokens[batch_start : batch_start + batch_size]
            for offset, prompt in enumerate(batch_prompts):
                if len(prompt) == 0:
                    continue
                prompt_idx = batch_start + offset
                prompt_array = mx.stop_gradient(mx.array(prompt))
                for _ in range(group_size):
                    completion = _sample_completion(
                        model, tokenizer, prompt_array, max_tokens, temp_sampler, end_ids
                    )
                    if not completion:
                        continue
                    # Record the source prompt per sample, so dropped empty
                    # completions cannot shift later samples to other prompts.
                    batch_indices.append(prompt_idx)
                    all_completions.append(mx.stop_gradient(mx.array(completion)))
                    all_completion_texts.append(tokenizer.decode(completion))
            mx.metal.clear_cache()
    finally:
        mx.metal.clear_cache()

    return all_completions, all_completion_texts, batch_indices
def _build_scoring_inputs(prompt_tokens, completions, batch_indices):
    """Concatenate every completion with the prompt that produced it.

    Returns ``(inputs, lengths, prompt_lengths)``: an ``(N, L)`` integer array
    of prompt+completion ids right-padded with 0, the unpadded length of each
    row, and the prompt length of each row.
    """
    sequences = []
    prompt_lengths = []
    for completion_ids, prompt_idx in zip(completions, batch_indices):
        prompt_ids = [int(t) for t in prompt_tokens[prompt_idx]]
        prompt_lengths.append(len(prompt_ids))
        sequences.append(prompt_ids + completion_ids.tolist())
    max_length = max(len(seq) for seq in sequences)
    inputs = mx.array([seq + [0] * (max_length - len(seq)) for seq in sequences])
    lengths = mx.array([len(seq) for seq in sequences])
    return inputs, lengths, prompt_lengths


def _completion_log_probs(model, inputs, lengths, prompt_lengths):
    """Score only the completion tokens of each prompt+completion row.

    Returns ``(log_probs, completion_lengths)``: a float32 ``(N, T)`` array of
    per-token log-probs right-padded with zeros, and the number of valid
    entries in each row.
    """
    per_sequence = get_per_token_logps(model, inputs, lengths)
    # Entry t of a row scores token t + 1, so the first completion token
    # (position prompt_length) is entry prompt_length - 1.
    completion_lps = [
        lp[max(p - 1, 0) :].astype(mx.float32)
        for lp, p in zip(per_sequence, prompt_lengths)
    ]
    max_len = max(lp.shape[0] for lp in completion_lps)
    log_probs = mx.stack(
        [
            mx.concatenate(
                [lp, mx.zeros((max_len - lp.shape[0],), dtype=mx.float32)]
            )
            for lp in completion_lps
        ]
    )
    mx.eval(log_probs)
    return log_probs, [lp.shape[0] for lp in completion_lps]


def _group_advantages(rewards, batch_indices, epsilon):
    """Normalize rewards within each prompt's group of completions.

    ``batch_indices`` must list each prompt's completions contiguously.
    Returns ``(advantages, group_means, group_stds)``; a group holding a single
    completion gets advantage 0.
    """
    group_sizes = {}
    for prompt_idx in batch_indices:
        group_sizes[prompt_idx] = group_sizes.get(prompt_idx, 0) + 1

    advantage_chunks, group_means, group_stds = [], [], []
    start = 0
    for size in group_sizes.values():
        group_rewards = rewards[start : start + size]
        mean = mx.mean(group_rewards)
        std = mx.std(group_rewards)  # 0 for a single-element group
        if size > 1:
            advantage_chunks.append((group_rewards - mean) / (std + epsilon))
        else:
            advantage_chunks.append(mx.zeros((1,), dtype=rewards.dtype))
        group_means.append(mean)
        group_stds.append(std)
        start += size
    return (
        mx.concatenate(advantage_chunks),
        mx.stack(group_means),
        mx.stack(group_stds),
    )


def _print_validation_sample(
    tokenizer, prompt_tokens, prompt_text, answer_text, batch_indices, completion_texts
):
    """Print the last prompt / generation / answer of a validation batch."""
    print("\n=== Validation Sample Details ===")
    last_prompt_idx = batch_indices[-1] if batch_indices else 0
    if last_prompt_idx < len(prompt_text):
        print(f"\n📋 Raw Prompt:\n{prompt_text[last_prompt_idx]}")
        print("\n" + "=" * 10 + "\n")
    if last_prompt_idx < len(prompt_tokens):
        actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
        print(f"\n🔄 Model Input:\n{actual_prompt}")
        print("\n" + "=" * 10 + "\n")
    print(f"\n📝 Generation:\n{completion_texts[-1]}")
    print("\n" + "=" * 10 + "\n")
    if last_prompt_idx < len(answer_text):
        print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
        print("\n" + "=" * 10 + "\n")
    print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(completion_texts[-1])}")
    print("\n" + "=" * 35 + "\n")


def grpo_loss(
    model,
    ref_model,
    tokenizer,
    batch,
    completions=None,
    completion_texts=None,
    batch_indices=None,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    beta: float = 0.1,
    group_size: int = 4,
    epsilon: float = 1e-4,
    max_tokens: int = 64,
    temperature: float = 0.8,
    reward_weights: Optional[List[float]] = None,
    batch_size: int = 1,
    is_validation: bool = False,
):
    """Compute the GRPO loss for one batch of prompts.

    Completions are scored conditioned on their prompts. The policy term uses
    the ratio ``exp(logp - stop_gradient(logp))``, which equals 1 but carries
    the policy gradient. The reference model only enters through the KL
    penalty.

    Args:
        model: Policy model being trained.
        ref_model: Frozen reference model, or ``None`` to use the policy itself
            (the KL term is then identically zero).
        tokenizer: Used for generation and validation printing.
        batch: ``(prompt_tokens, answer_tokens, prompt_text, answer_text)``.
        completions, completion_texts, batch_indices: Pre-generated output of
            ``generate_grpo``; when any is ``None``, completions are generated
            here.
        reward_funcs: Non-empty list of callables
            ``f(prompts=..., completions=..., answer=...)`` returning one
            score per completion.
        beta: KL penalty coefficient.
        group_size, max_tokens, temperature, batch_size: Generation settings
            used only when completions are not supplied.
        epsilon: Added to the per-group reward std when normalizing.
        reward_weights: Optional per-function weights (default all 1.0).
        is_validation: Print one sample when ``True``.

    Returns:
        ``(loss, num_completion_tokens, metrics)``.

    Raises:
        ValueError: No reward functions, no completions, or a mismatched
            number of reward weights.
    """
    if not reward_funcs:
        raise ValueError("At least one reward function is required.")
    prompt_tokens, _, prompt_text, answer_text = batch

    if completions is None or completion_texts is None or batch_indices is None:
        completions, completion_texts, batch_indices = generate_grpo(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_tokens=max_tokens,
            group_size=group_size,
            temperature=temperature,
            batch_size=batch_size,
        )

    if not completions:
        raise ValueError(
            "No completions were generated. Please check your model and inputs."
        )

    # Group completions by prompt; the sort is stable, so each group keeps
    # its generation order.
    order = sorted(range(len(batch_indices)), key=lambda k: batch_indices[k])
    all_completions = [completions[k] for k in order]
    all_completion_texts = [completion_texts[k] for k in order]
    batch_indices = [batch_indices[k] for k in order]
    expanded_prompts = [prompt_text[p] for p in batch_indices]
    expanded_answers = [answer_text[p] for p in batch_indices]

    # Score completion tokens in the context of their prompts.
    inputs, lengths, prompt_lengths = _build_scoring_inputs(
        prompt_tokens, all_completions, batch_indices
    )
    token_log_probs, completion_lengths = _completion_log_probs(
        model, inputs, lengths, prompt_lengths
    )
    if ref_model is None:
        ref_token_log_probs = token_log_probs
    else:
        ref_token_log_probs, _ = _completion_log_probs(
            ref_model, inputs, lengths, prompt_lengths
        )
        ref_token_log_probs = mx.stop_gradient(ref_token_log_probs)

    # Evaluate every reward function once; reused for the loss and metrics.
    func_rewards = [
        mx.array(
            reward_func(
                prompts=expanded_prompts,
                completions=all_completion_texts,
                answer=expanded_answers,
            )
        )
        for reward_func in reward_funcs
    ]
    if reward_weights is not None:
        if len(reward_weights) != len(reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
                f"functions ({len(reward_funcs)})"
            )
        weights = mx.array(reward_weights, dtype=mx.float32)
    else:
        weights = mx.ones(len(reward_funcs), dtype=mx.float32)
    rewards = (mx.stack(func_rewards, axis=1) * weights[None, :]).sum(axis=1)

    advantages, group_means, group_stds = _group_advantages(
        rewards, batch_indices, epsilon
    )

    # KL divergence via Schulman's k3 estimator.
    log_ratio_ref = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(log_ratio_ref) - log_ratio_ref - 1

    length_mask = (
        mx.arange(token_log_probs.shape[1])[None, :]
        < mx.array(completion_lengths)[:, None]
    )

    # Ratio against the detached current policy (== 1, gradient of logp).
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
    per_token_loss = -(
        (policy_ratio * advantages[:, None] - beta * kl_div) * length_mask
    )

    sequence_lengths = length_mask.sum(axis=1)
    # Guard against division by zero for rows with no scored tokens.
    safe_lengths = mx.maximum(sequence_lengths, 1)
    loss = (per_token_loss.sum(axis=1) / safe_lengths).mean()
    mean_kl = ((kl_div * length_mask).sum(axis=1) / safe_lengths).mean()

    metrics = {
        "total_rewards_mean": mx.mean(rewards),
        "total_rewards_std": mx.std(rewards),
        "grouped_rewards_mean": mx.mean(group_means),
        "grouped_rewards_std": mx.mean(group_stds),
        "kl": mean_kl,
    }
    for reward_func, values in zip(reward_funcs, func_rewards):
        func_name = reward_func.__name__
        metrics[f"{func_name}_mean"] = mx.mean(values)
        metrics[f"{func_name}_std"] = mx.std(values)

    if is_validation and all_completion_texts:
        _print_validation_sample(
            tokenizer,
            prompt_tokens,
            prompt_text,
            answer_text,
            batch_indices,
            all_completion_texts,
        )

    mx.metal.clear_cache()

    return loss, sequence_lengths.sum(), metrics
2025-01-31 06:55:34 +08:00
2025-02-12 18:07:53 +08:00
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
    """Yield GRPO batches from ``dataset``.

    Indices are sorted by combined prompt+answer token length and cut into
    consecutive windows of ``batch_size``. Each window keeps every
    ``world_size``-th index.

    Args:
        dataset: Sequence of ``(prompt_tokens, answer_tokens, prompt_str,
            answer_str)`` tuples.
        batch_size: Window size; must be divisible by the number of workers.
        max_seq_length: Prompts longer than this are truncated to their first
            ``max_seq_length`` tokens, matching the warning that is printed.
        train: If ``True``, shuffle batch order every pass and loop forever;
            otherwise make a single ordered pass.

    Yields:
        ``(prompts_tokens, answers_tokens, prompts_text, answers_text)`` lists.

    NOTE(review): every rank selects the same indices here (no rank offset is
    applied) — confirm this is intended for distributed runs.
    """
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError(
            "Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
        )

    def length_key(i):
        return len(dataset[i][0]) + len(dataset[i][1])

    idx = sorted(range(len(dataset)), key=length_key)

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    def batch_index_generator():
        for i in range(0, len(idx) - batch_size + 1, batch_size):
            yield idx[i : i + batch_size : step]

    while True:
        indices = (
            np.random.permutation(list(batch_index_generator()))
            if train
            else batch_index_generator()
        )

        for batch_idx in indices:
            current_batch = [dataset[j] for j in batch_idx]

            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )
                # Previously only warned; actually truncate, keeping the first
                # tokens as mlx_lm's SFT iterate_batches does.
                prompts_tokens = [p[:max_seq_length] for p in prompts_tokens]

            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break
def evaluate_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length: int,
    max_tokens: int,
    temperature: float,
    # NOTE: shared default list; it is only read, never mutated.
    reward_funcs: Optional[List[RewardFunctions]] = [
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml,
    ],
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
):
    """Evaluate the GRPO loss and metrics over (part of) ``dataset``.

    Loss and every metric are averaged with per-batch token counts as weights
    and summed across distributed workers.

    Args:
        num_batches: Number of batches to evaluate, or ``-1`` for a full pass.
        Other arguments are forwarded to ``iterate_batches`` / ``loss_fn``.

    Returns:
        ``(avg_loss, ntokens, avg_metrics)`` with ``avg_loss`` and the
        ``avg_metrics`` values as Python floats.

    Raises:
        ValueError: If no batch was evaluated (previously an AttributeError
            on ``None.items()``).
    """
    all_losses = 0
    ntokens = 0
    all_metrics = None

    # iter(int, 1) never terminates: int() returns 0, never the sentinel 1.
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model,
            temperature=temperature,
            max_tokens=max_tokens,
            is_validation=True,
        )

        all_losses += losses * toks
        ntokens += toks

        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        # Evaluate every iteration so the lazy graphs do not keep growing.
        mx.eval(all_losses, ntokens, all_metrics)

    if all_metrics is None:
        raise ValueError(
            "No validation batches were evaluated; check num_batches and the dataset."
        )

    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    # Use the same CPU stream as the loss/token reductions above.
    all_metrics = {
        k: mx.distributed.all_sum(v, stream=mx.cpu) for k, v in all_metrics.items()
    }

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, ntokens, avg_metrics
2025-01-31 23:27:31 +08:00
2025-01-31 06:55:34 +08:00
2025-01-31 23:57:43 +08:00
def _zeroed_metrics(reward_funcs):
    """Return a fresh metrics accumulator with every reported key set to 0."""
    keys = [
        "total_rewards_mean",
        "total_rewards_std",
        "grouped_rewards_mean",
        "grouped_rewards_std",
        "kl",
    ]
    for reward_func in reward_funcs:
        keys.append(f"{reward_func.__name__}_mean")
        keys.append(f"{reward_func.__name__}_std")
    return {key: mx.array(0.0) for key in keys}


def train_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    # NOTE: shared default list; it is only read, never mutated.
    reward_funcs: Optional[List[RewardFunctions]] = [
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml,
    ],
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
    """Run the GRPO training loop.

    Each step samples completions with ``generate_grpo`` outside the gradient
    computation, then differentiates ``loss_fn`` on those completions and
    applies the averaged gradients with ``optimizer``. The loop validates
    periodically, reports averaged training metrics (summed across workers
    and reset after every report), and saves adapter weights.
    """
    print(
        f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}"
    )
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]

    # Forward reward weights only when configured, so custom loss_fn
    # implementations without that parameter keep working.
    weight_kwargs = {}
    configured_weights = getattr(args, "reward_weights", None)
    if configured_weights is not None:
        weight_kwargs["reward_weights"] = configured_weights

    def step(batch):
        prompt_tokens, answer_tokens, prompt_text, answer_text = batch

        # Sample completions outside value_and_grad: generation itself is not
        # differentiated.
        all_completions, all_completion_texts, batch_indices = generate_grpo(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_tokens=args.max_completion_length,
            group_size=args.group_size,
            temperature=args.temperature,
        )

        (loss, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            batch=(prompt_tokens, answer_tokens, prompt_text, answer_text),
            completions=all_completions,
            completion_texts=all_completion_texts,
            batch_indices=batch_indices,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
            **weight_kwargs,
        )

        grad = average_gradients(grad)
        optimizer.update(model, grad)
        return loss, toks, metrics

    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = _zeroed_metrics(reward_funcs)

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss_fn=loss_fn,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                max_tokens=args.max_completion_length,
                beta=args.beta,
                epsilon=args.epsilon,
                temperature=args.temperature,
                iterate_batches=iterate_batches,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.3f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )
                for reward_func in reward_funcs:
                    val_metrics_str += (
                        f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
                        f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                    )
                print(
                    f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report(
                    {
                        "iteration": it,
                        "val_loss": val_loss,
                        **{f"val_{k}": v for k, v in val_metrics.items()},
                        "val_time": val_time,
                    }
                )

            start = time.perf_counter()

        loss, toks, metrics = step(batch)
        losses += loss
        n_tokens += toks
        steps += 1

        mx.metal.clear_cache()

        for k, v in metrics.items():
            accumulated_metrics[k] += v

        # Evaluate metrics too, so their lazy graphs do not grow across steps.
        mx.eval(state, losses, n_tokens, accumulated_metrics)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            denom = steps * world_size
            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() / denom
            # Sum metrics across workers before dividing by world_size
            # (previously only local sums were divided).
            avg_metrics = {
                k: mx.distributed.all_sum(v, stream=mx.cpu).item() / denom
                for k, v in accumulated_metrics.items()
            }
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.3f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )
                for reward_func in reward_funcs:
                    func_name = reward_func.__name__
                    train_metrics_str += (
                        f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
                        f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
                    )
                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report(
                    {
                        "iteration": it,
                        "train_loss": train_loss,
                        **{f"train_{k}": v for k, v in avg_metrics.items()},
                        "learning_rate": learning_rate,
                        "iterations_per_second": it_sec,
                        "tokens_per_second": tokens_sec,
                        "trained_tokens": trained_tokens,
                        "peak_memory": peak_mem,
                    }
                )

            losses = 0
            n_tokens = 0
            steps = 0
            # Reset alongside losses/steps; previously metrics kept
            # accumulating across reports and averages were inflated.
            accumulated_metrics = _zeroed_metrics(reward_funcs)
            start = time.perf_counter()

        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")