2025-01-31 06:55:34 +08:00
|
|
|
# Copyright © 2024 Apple Inc.
|
|
|
|
|
|
|
|
import time
|
|
|
|
from dataclasses import dataclass, field
|
|
|
|
from pathlib import Path
|
2025-02-03 15:26:42 +08:00
|
|
|
import re
|
2025-01-31 06:55:34 +08:00
|
|
|
|
|
|
|
import mlx.core as mx
|
|
|
|
import mlx.nn as nn
|
|
|
|
import numpy as np
|
|
|
|
from mlx.utils import tree_flatten
|
|
|
|
|
2025-02-03 17:08:28 +08:00
|
|
|
from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients, iterate_batches
|
2025-01-31 06:55:34 +08:00
|
|
|
|
|
|
|
@dataclass
class GRPOTrainingArgs(TrainingArgs):
    """GRPO-specific hyper-parameters layered on top of the shared ``TrainingArgs``."""

    # Number of completions sampled per prompt; rewards are normalized within
    # each such group by grpo_loss.
    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    # Weight of the per-token KL term in the loss.
    beta: float = field(
        default=0.1, metadata={"help": "KL penalty coefficient."}
    )
    # Added to the per-group reward standard deviation before dividing.
    epsilon: float = field(
        default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
    )
    # Passed by train_grpo to the loss as ``max_tokens``, i.e. the generation
    # budget for each sampled completion.
    max_completion_length: int = field(
        default=512,
        metadata={"help": "Maximum number of tokens to generate per completion."},
    )
    # Default is None even though the annotation says ``str``.
    reference_model_path: str = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        }
    )
|
|
|
|
|
2025-02-05 18:30:21 +08:00
|
|
|
|
2025-02-03 15:26:42 +08:00
|
|
|
def r1_extract_xml_answer(text: str) -> str:
    """Return the stripped text between the last ``<answer>`` and the next ``</answer>``.

    If ``<answer>`` is absent, the search starts at the beginning of ``text``.
    If ``</answer>`` is absent, everything after the last ``<answer>`` is kept.
    Non-string input (e.g. ``None``) yields ``""``.
    """
    try:
        answer = text.split("<answer>")[-1]
        answer = answer.split("</answer>")[0]
        return answer.strip()
    except (AttributeError, TypeError):
        # Only non-str input can fail here. A bare ``except`` would also have
        # swallowed KeyboardInterrupt / SystemExit.
        print("r1_extract_xml_answer returned empty string")
        return ""
|
|
|
|
|
2025-02-05 21:38:09 +08:00
|
|
|
|
2025-02-03 19:05:29 +08:00
|
|
|
def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward 0.5 per completion whose extracted answer is all digits (``str.isdigit``).

    Every other completion scores 0.0. ``prompts`` is only used to size the
    fallback: with no completions, ``[0.0] * len(prompts)`` is returned.
    """
    if not completions:
        return [0.0] * len(prompts)

    extracted_answers = list(map(r1_extract_xml_answer, completions))
    scores = []
    for candidate in extracted_answers:
        if candidate and candidate.isdigit():
            scores.append(0.5)
        else:
            scores.append(0.0)
    return scores
|
2025-02-03 15:26:42 +08:00
|
|
|
|
2025-02-05 16:48:00 +08:00
|
|
|
def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward 2.0 per completion whose extracted answer exactly equals its reference answer.

    Completions and reference answers are paired by position, and the result
    has one entry per pair (``zip`` semantics). Empty extractions or empty
    references score 0.0. When either list is empty, ``[0.0] * len(prompts)``
    is returned.
    """
    if not (completions and answer):
        return [0.0] * len(prompts)

    predictions = list(map(r1_extract_xml_answer, completions))
    scores = []
    for predicted, expected in zip(predictions, answer):
        if predicted and expected and predicted == expected:
            scores.append(2.0)
        else:
            scores.append(0.0)
    return scores
|
2025-02-05 16:48:00 +08:00
|
|
|
|
2025-02-05 21:38:09 +08:00
|
|
|
|
2025-02-03 19:05:29 +08:00
|
|
|
def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward 0.5 per completion containing ``<think>...</think>`` followed by ``<answer>...</answer>``.

    The tags may appear anywhere in the completion, separated only by
    whitespace, and their contents may span several lines. Empty completions
    score 0.0. With no completions, ``[0.0] * len(prompts)`` is returned.
    """
    if not completions:
        return [0.0] * len(prompts)

    # re.DOTALL lets ".*?" cross newlines. Without it, any multi-line <think>
    # block (the normal case, and required by the strict format) never matched.
    pattern = re.compile(r"<think>.*?</think>\s*<answer>.*?</answer>", re.DOTALL)
    return [0.5 if r and pattern.search(r) else 0.0 for r in completions]
|
|
|
|
|
2025-02-05 21:38:09 +08:00
|
|
|
|
2025-02-03 19:05:29 +08:00
|
|
|
def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward 0.5 per completion that is exactly the newline-delimited think/answer scaffold.

    The completion must start with ``"<think>\\n"`` and contain
    ``"\\n</think>\\n<answer>\\n"``. It must end with ``"\\n</answer>"``,
    optionally followed by a newline (``$`` additionally tolerates one more
    final newline). The think and answer bodies may span several lines.
    Empty completions score 0.0. With no completions,
    ``[0.0] * len(prompts)`` is returned.
    """
    if not completions:
        return [0.0] * len(prompts)

    # re.DOTALL: reasoning is normally multi-line, and without it ".*?" could
    # only match single-line bodies.
    # "\n?" before "$": generate_grpo stops right after the "</answer>" tokens,
    # so a mandatory trailing newline made this reward unreachable.
    pattern = re.compile(
        r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n?$", re.DOTALL
    )
    return [0.5 if r and pattern.search(r) else 0.0 for r in completions]
|
|
|
|
|
2025-02-05 21:38:09 +08:00
|
|
|
|
2025-02-03 19:05:29 +08:00
|
|
|
def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Give partial credit for each XML scaffold piece present in a completion.

    Each of ``"<think>\\n"``, ``"\\n</think>\\n"``, ``"\\n<answer>\\n"`` and
    ``"\\n</answer>"`` adds 0.125 when it occurs exactly once.

    When the closing answer tag is present, 0.001 is subtracted per character
    after its last occurrence. One newline directly after the tag is free.

    Scores are clamped to be non-negative. Empty completions score 0.0. With
    no completions, ``[0.0] * len(prompts)`` is returned.
    """
    if not completions:
        return [0.0] * len(prompts)

    # No trailing "\n" here: generate_grpo stops immediately after "</answer>",
    # so requiring one meant the tag was never credited.
    closing_tag = "\n</answer>"

    scores = []
    for text in completions:
        if not text:
            scores.append(0.0)
            continue

        count = 0.0
        if text.count("<think>\n") == 1:
            count += 0.125
        if text.count("\n</think>\n") == 1:
            count += 0.125
        if text.count("\n<answer>\n") == 1:
            count += 0.125
        if text.count(closing_tag) == 1:
            count += 0.125

        # Penalize extra text after </answer>. This only applies when the tag
        # exists. Previously a missing tag made the whole text count as "extra".
        if closing_tag in text:
            trailing = text.rsplit(closing_tag, 1)[1]
            if trailing.startswith("\n"):
                trailing = trailing[1:]
            count -= len(trailing) * 0.001

        scores.append(max(0.0, count))  # Ensure non-negative score

    return scores
|
2025-01-31 06:55:34 +08:00
|
|
|
|
|
|
|
|
2025-02-05 16:48:00 +08:00
|
|
|
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
    """Sample up to ``max_tokens`` tokens after ``prompt``, one forward pass per token.

    Args:
        model: Called on a ``(1, n)`` array of token ids. ``logits[0, -1]`` is
            used as the next-token distribution.
        prompt: Token-id array, 1-D or 2-D. Only row 0 is used when 2-D.
        max_tokens: Maximum number of new tokens to generate.
        tokenizer: Provides ``encode``, ``eos_token_id`` and optionally
            ``bos_token_id``.
        temperature: Sampling temperature. Values ``<= 0`` select greedy
            (argmax) decoding.

    Generation stops at EOS, after the token sequence for ``"</answer>"``, or
    after ``max_tokens`` tokens. The full sequence is re-run every step (no
    KV cache).

    Returns:
        A 1-D int32 array holding the prompt followed by the generated tokens.
        Returns ``None`` if the prompt is empty, nothing was generated, or
        generation raised (the error is printed).
    """
    if len(prompt.shape) == 1:
        prompt = prompt[None, :]
    if prompt.shape[1] == 0:
        return None

    prompt_length = prompt.shape[1]

    end_sequence = list(tokenizer.encode("</answer>"))
    # Some tokenizers prepend BOS when encoding. BOS never appears right before
    # "</answer>" in generated text, so keeping it would make the stop
    # sequence unmatchable.
    bos_token_id = getattr(tokenizer, "bos_token_id", None)
    if bos_token_id is not None and end_sequence and end_sequence[0] == bos_token_id:
        end_sequence = end_sequence[1:]
    end_sequence_length = len(end_sequence)

    output = mx.zeros((prompt_length + max_tokens,), dtype=mx.int32)
    output[:prompt_length] = prompt[0]
    current_length = prompt_length

    def sample(logits):
        if temperature <= 0:
            # Greedy decoding. Previously temperature 0 silently sampled at T=1.
            return mx.argmax(logits, axis=-1).astype(mx.int32)
        scaled = logits / temperature
        logprobs = scaled - mx.logsumexp(scaled, keepdims=True)
        return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]

    try:
        for _ in range(max_tokens):
            logits = model(output[:current_length][None, :])
            token_value = sample(logits[0, -1]).item()
            output[current_length] = token_value
            current_length += 1

            if token_value == tokenizer.eos_token_id:
                break

            # An empty stop sequence would match immediately, so require one.
            if end_sequence_length and current_length >= end_sequence_length:
                last_tokens = output[current_length - end_sequence_length:current_length].tolist()
                if last_tokens == end_sequence:
                    break
    except Exception as e:
        print(f"Generation error: {str(e)}")
        return None

    if current_length > prompt_length:
        return output[:current_length]
    return None
|
|
|
|
|
|
|
|
|
2025-02-04 04:57:26 +08:00
|
|
|
def get_per_token_logps(model, inputs, lengths):
    """Return, per row of ``inputs``, the log-probability of each next token.

    Args:
        model: Called as ``model(inputs)``. Its output is indexed as
            ``[:, :-1, :]`` (batch, position, vocab).
        inputs: Right-padded 2-D array of token ids.
        lengths: Unpadded length of each row of ``inputs``.

    Returns:
        A list with one 1-D float32 array per row. Entry ``i`` has
        ``lengths[i] - 1`` elements, where element ``t`` is
        ``log p(inputs[i, t + 1] | inputs[i, :t + 1])``.
    """
    # The last position predicts a token beyond the sequence, so drop it.
    logits = model(inputs)[:, :-1, :]
    targets = inputs[:, 1:]
    mx.eval(logits)

    per_token_logps = []
    for i in range(logits.shape[0]):
        seq_len = int(lengths[i]) - 1

        # Log-softmax in float32. Casting all logits to float16 first (as
        # before) gave coarse log-probs (~1e-3 resolution, noisy ratios and KL)
        # and can overflow bfloat16 logits. The upcast is done per sequence to
        # bound the temporary, though the backward pass keeps these float32
        # activations.
        seq_logits = logits[i, :seq_len].astype(mx.float32)
        log_probs = nn.log_softmax(seq_logits, axis=-1)

        token_log_probs = mx.take_along_axis(
            log_probs,
            targets[i, :seq_len].reshape(seq_len, 1),
            axis=-1,
        ).squeeze(-1)
        per_token_logps.append(token_log_probs)

    return per_token_logps
|
|
|
|
|
|
|
|
|
2025-01-31 06:55:34 +08:00
|
|
|
def _sample_grpo_completions(model, tokenizer, prompt_tokens, group_size, max_tokens, temperature):
    """Sample ``group_size`` completions per prompt and return parallel prompt-major lists.

    Returns ``(sequences, texts, prompt_lengths)``:
    - ``sequences`` holds prompt + generated token ids, as returned by ``generate_grpo``.
    - ``texts`` decodes only the generated tokens.
    - ``prompt_lengths`` holds each sequence's prompt length.

    Raises:
        RuntimeError: If a completion could not be generated.
    """
    sequences, texts, prompt_lengths = [], [], []
    for prompt in prompt_tokens:
        prompt_tensor = mx.array(prompt)
        prompt_length = prompt_tensor.shape[-1]
        for _ in range(group_size):
            completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
            if completion_ids is None:
                # Dropping the sample would misalign completions with their
                # answers and break the (batch_size, group_size) reshape.
                raise RuntimeError(
                    "GRPO generation failed for a prompt; cannot build a complete group."
                )
            mx.eval(completion_ids)
            sequences.append(completion_ids)
            # Rewards must judge only what the model generated, not the prompt.
            texts.append(tokenizer.decode(completion_ids[prompt_length:].tolist()))
            prompt_lengths.append(prompt_length)
    mx.metal.clear_cache()
    return sequences, texts, prompt_lengths


def _pad_and_stack(sequences, length):
    """Right-pad each 1-D array in ``sequences`` with zeros to ``length`` and stack them."""
    rows = []
    for seq in sequences:
        pad = length - seq.shape[0]
        if pad > 0:
            seq = mx.concatenate([seq, mx.zeros((pad,), dtype=seq.dtype)])
        rows.append(seq)
    return mx.stack(rows)


def grpo_loss(
    model,
    ref_model,
    tokenizer,
    batch,
    reward_funcs=None,
    beta=0.1,
    group_size=4,
    epsilon=1e-4,
    max_tokens=64,
    temperature=1.0
):
    """Compute the GRPO loss for one batch of prompts.

    For each prompt, ``group_size`` completions are sampled from ``model`` and
    scored by every reward function. The mean reward is normalized within each
    group to form advantages. The loss is the negative of
    ``ratio * advantage - beta * KL``, averaged over each sequence's generated
    tokens and then over sequences.

    Args:
        model: Policy being trained.
        ref_model: Reference model for the KL term, or ``None``. With ``None``
            the policy is its own reference and the KL term is zero.
        tokenizer: Used for generation and for decoding completions.
        batch: ``(prompt_tokens, answer_tokens, prompt_text, answer_text)``.
        reward_funcs: Callables ``f(prompts=, completions=, answer=)``
            returning one float per completion. ``None`` means no rewards.
        beta: KL penalty coefficient.
        group_size: Completions sampled per prompt.
        epsilon: Added to the per-group reward std before dividing.
        max_tokens: Generation budget per completion.
        temperature: Sampling temperature for ``generate_grpo``.

    Returns:
        ``(loss, ntokens, metrics)``. ``ntokens`` is the number of generated
        tokens that entered the loss. ``metrics`` maps names to scalar arrays.

    Raises:
        RuntimeError: If any completion could not be generated.
    """
    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
    batch_size = len(prompt_tokens)
    reward_funcs = [] if reward_funcs is None else list(reward_funcs)

    all_completions, all_completion_texts, prompt_lengths = _sample_grpo_completions(
        model, tokenizer, prompt_tokens, group_size, max_tokens, temperature
    )

    # Repeat each prompt/answer group_size times so entry k matches completion k.
    expanded_prompts = [prompt_text[i] for i in range(batch_size) for _ in range(group_size)]
    expanded_answers = [answer_text[i] for i in range(batch_size) for _ in range(group_size)]

    lengths = mx.array([ids.shape[0] for ids in all_completions])
    inputs = _pad_and_stack(all_completions, max(ids.shape[0] for ids in all_completions))

    # Current policy probabilities
    token_log_probs = get_per_token_logps(model, inputs, lengths)
    mx.eval(token_log_probs)
    mx.metal.clear_cache()

    # Reference policy probabilities
    if ref_model is None:
        ref_token_log_probs = token_log_probs
    else:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)

    max_len = max(x.shape[0] for x in token_log_probs)
    token_log_probs = _pad_and_stack(token_log_probs, max_len).astype(mx.float32)
    ref_token_log_probs = mx.stop_gradient(
        _pad_and_stack(ref_token_log_probs, max_len).astype(mx.float32)
    )

    # Rewards: each function is evaluated once, and its values are reused for
    # the metrics below.
    rewards = mx.zeros((len(all_completions),))
    per_func_rewards = []
    for reward_func in reward_funcs:
        func_rewards = mx.array(reward_func(
            prompts=expanded_prompts,
            completions=all_completion_texts,
            answer=expanded_answers
        ))
        per_func_rewards.append((reward_func.__name__, func_rewards))
        rewards += func_rewards
    if len(reward_funcs) > 1:
        rewards /= len(reward_funcs)

    # Group-normalized advantages
    rewards_reshaped = rewards.reshape(batch_size, group_size)
    group_mean = mx.mean(rewards_reshaped, axis=1, keepdims=True)
    group_std = mx.std(rewards_reshaped, axis=1, keepdims=True)
    advantages = ((rewards_reshaped - group_mean) / (group_std + epsilon)).reshape(-1)

    # Per-token KL(pi || pi_ref) with Schulman's k3 estimator r - log r - 1,
    # where r = pi_ref / pi (samples come from pi).
    log_ratio = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(log_ratio) - log_ratio - 1

    # Log-prob index j scores inputs[:, j + 1], so sequence i's generated
    # tokens occupy j in [prompt_len_i - 1, length_i - 1). Prompt tokens are
    # excluded from both the policy-gradient and the KL terms.
    positions = mx.arange(inputs.shape[1] - 1)[None, :]
    completion_start = mx.array(prompt_lengths)[:, None] - 1
    length_mask = mx.logical_and(
        positions >= completion_start, positions < lengths[:, None] - 1
    )

    # Samples were drawn from the current policy, so the importance ratio is
    # pi / stop_gradient(pi): value 1 with gradient d(log pi).
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

    # Per-token loss following the GRPO objective
    per_token_loss = -((policy_ratio * advantages[:, None] - beta * kl_div) * length_mask)

    # Average over each sequence's generated tokens, then over sequences.
    # Each completion has at least one generated token, so no division by zero.
    sequence_lengths = length_mask.sum(axis=1)
    loss = (per_token_loss.sum(axis=1) / sequence_lengths).mean()

    mean_kl = ((kl_div * length_mask).sum(axis=1) / sequence_lengths).mean()

    reward_metrics = {}
    for func_name, func_rewards in per_func_rewards:
        reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)

    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(rewards_reshaped),
        'grouped_rewards_std': mx.std(rewards_reshaped),
        'kl': mean_kl,
        **reward_metrics
    }
    mx.metal.clear_cache()

    return loss, sequence_lengths.sum(), metrics
|
2025-01-31 06:55:34 +08:00
|
|
|
|
|
|
|
|
2025-02-03 17:08:28 +08:00
|
|
|
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """Yield GRPO batches, grouping examples of similar length.

    Args:
        dataset: Indexable sequence of
            ``(prompt_tokens, answer_tokens, prompt_str, answer_str)`` tuples.
        tokenizer: Unused here. Kept so the signature matches other batch
            iterators.
        batch_size: Global batch size. Must not exceed ``len(dataset)`` and
            must be divisible by the distributed world size.
        max_seq_length: Prompts longer than this trigger a warning. They are
            passed through unchanged.
        train: If True, loop forever and shuffle batch order each pass.
            Otherwise make a single pass in length order.

    Yields:
        ``(prompts_tokens, answers_tokens, prompts_text, answers_text)`` lists.

    Raises:
        ValueError: Malformed dataset, dataset smaller than ``batch_size``, or
            ``batch_size`` not divisible by the number of workers.
    """
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

    # Sort example indices by combined prompt+answer token length so each
    # batch holds similarly sized examples. The full index list is kept in memory.
    def length_key(i):
        return len(dataset[i][0]) + len(dataset[i][1])

    idx = sorted(range(len(dataset)), key=length_key)

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    # Each batch takes every `step`-th index of a batch_size window; the
    # trailing partial window is dropped.
    # NOTE(review): no rank offset is applied here, so every worker appears to
    # take the same slice — confirm whether that is intended.
    def batch_index_generator():
        for i in range(0, len(idx) - batch_size + 1, batch_size):
            yield idx[i : i + batch_size : step]

    while True:
        indices = (
            np.random.permutation(list(batch_index_generator())) if train
            else batch_index_generator()
        )

        for batch_idx in indices:
            current_batch = [dataset[j] for j in batch_idx]

            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            if any(len(p) > max_seq_length for p in prompts_tokens):
                # Nothing is truncated here; the old message claimed otherwise.
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts are not truncated."
                )

            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_grpo(
    model,
    ref_model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length,
    reward_funcs = None,
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches
):
    """
    Evaluate model using GRPO loss.

    Each batch's loss and metrics are weighted by the token count returned by
    ``loss_fn``, summed across distributed workers, and divided by the total
    token count. ``num_batches == -1`` evaluates until the iterator is exhausted.

    Returns:
        tuple: (average loss, number of tokens, average metrics)

    Raises:
        ValueError: If no batches were evaluated.
    """
    all_losses = 0
    ntokens = 0
    all_metrics = None

    # Create iterator for batches (iter(int, 1) never stops)
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model
        )

        # Accumulate token-weighted losses and metrics
        all_losses += losses * toks
        ntokens += toks
        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        # Evaluate the metric accumulators too so their lazy graphs don't grow per batch.
        mx.eval(all_losses, ntokens, all_metrics)

    if all_metrics is None:
        raise ValueError(
            "evaluate_grpo: no batches were evaluated; check num_batches and the dataset size."
        )

    # Aggregate across distributed workers. Every reduction uses the CPU
    # stream, matching the loss/token reductions (the metrics one previously
    # omitted it).
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    all_metrics = {k: mx.distributed.all_sum(v, stream=mx.cpu) for k, v in all_metrics.items()}

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, ntokens, avg_metrics
|
2025-01-31 23:27:31 +08:00
|
|
|
|
2025-01-31 06:55:34 +08:00
|
|
|
|
2025-01-31 23:57:43 +08:00
|
|
|
def train_grpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs = None,
    args: GRPOTrainingArgs = GRPOTrainingArgs(),
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
    """Train ``model`` with GRPO, periodically validating, reporting and saving adapters.

    Args:
        model: Policy model. Its trainable parameters are updated and saved.
        ref_model: Reference model forwarded to ``loss_fn`` (may be ``None``).
        tokenizer: Forwarded to the batch iterator and the loss.
        optimizer: Optimizer whose ``update(model, grad)`` applies gradients.
        train_dataset: Training data for ``iterate_batches``.
        val_dataset: Validation data for ``evaluate_grpo``.
        reward_funcs: Reward callables. ``None`` selects the default set:
            accuracy, int, strict format, soft format, XML count.
        args: Hyper-parameters (iters, batch sizes, report/eval/save cadence, ...).
        loss_fn: Returns ``(loss, ntokens, metrics)``. Differentiated w.r.t. ``model``.
        iterate_batches: Batch iterator factory.
        training_callback: Optional receiver of train/val report dicts.

    Adapter weights are written to ``args.adapter_file`` every
    ``args.steps_per_save`` iterations (plus a numbered checkpoint) and once
    more at the end.
    """
    if reward_funcs is None:
        reward_funcs = [
            r1_accuracy_reward_func,
            r1_int_reward_func,
            r1_strict_format_reward_func,
            r1_soft_format_reward_func,
            r1_count_xml
        ]

    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]
    loss_value_and_grad = nn.value_and_grad(model, loss_fn)

    def step(batch):
        # Forward and backward pass
        (loss, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
            max_tokens=args.max_completion_length,
        )

        # All reduce the gradients if running in distributed mode
        grad = average_gradients(grad)

        # Model update
        optimizer.update(model, grad)

        return loss, toks, metrics

    def zero_metrics():
        # Fresh per-report accumulators, one zero array per expected metric key.
        names = [
            'total_rewards_mean',
            'total_rewards_std',
            'grouped_rewards_mean',
            'grouped_rewards_std',
            'kl',
        ]
        for reward_func in reward_funcs:
            names.append(f'{reward_func.__name__}_mean')
            names.append(f'{reward_func.__name__}_std')
        return {name: mx.array(0.0) for name in names}

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = zero_metrics()

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        # Report validation loss if needed, the first validation loss
        # is always measured before any training.
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss_fn=loss_fn,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                beta=args.beta,
                epsilon=args.epsilon,
                iterate_batches=iterate_batches,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.8f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )

                # Add reward function specific metrics. The old trailing ", "
                # produced ", , " in the log.
                for reward_func in reward_funcs:
                    val_metrics_str += (
                        f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}"
                    )

                print(
                    f"Iter {it}: {val_metrics_str}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            # Exclude validation time from training throughput.
            start = time.perf_counter()

        loss, toks, metrics = step(batch)
        losses += loss
        n_tokens += toks
        steps += 1

        for k, v in metrics.items():
            accumulated_metrics[k] = accumulated_metrics.get(k, mx.array(0.0)) + v

        # Evaluate the metric accumulators too, otherwise their lazy graphs
        # keep growing (and holding memory) until the next report.
        mx.eval(state, losses, n_tokens, accumulated_metrics)

        # Report training loss if needed
        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            denom = steps * world_size
            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= denom
            # Sum across workers before dividing by steps * world_size.
            # Dividing only the local sum under-reported in distributed runs.
            avg_metrics = {
                k: mx.distributed.all_sum(v, stream=mx.cpu).item() / denom
                for k, v in accumulated_metrics.items()
            }
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.8f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )

                # Add reward function specific metrics
                for reward_func in reward_funcs:
                    func_name = reward_func.__name__
                    train_metrics_str += (
                        f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
                        f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })

            # Reset every per-report accumulator. The metrics were previously
            # never reset, which skewed all reports after the first.
            losses = 0
            n_tokens = 0
            steps = 0
            accumulated_metrics = zero_metrics()
            start = time.perf_counter()

        # Save adapter weights
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")
|