Goekdeniz-Guelmez 2025-02-03 19:37:05 +01:00
parent 1d9e4802f0
commit 05d921b788

@@ -35,68 +35,50 @@ class GRPOTrainingArgs(TrainingArgs):
    )
def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
    model.eval()
    # Ensure prompt is the right shape
    if len(prompt.shape) == 1:
        prompt = prompt[None, :]
    generated = []
    current_prompt = prompt[0]
    for _ in range(max_tokens):
        current_batch = current_prompt[None, :]
        logits = model(current_batch)
        token_logits = logits[0, -1]
        # Apply temperature and get probabilities
        if temperature > 0:
            token_logits = token_logits / temperature
        probs = mx.softmax(token_logits)
        # Sample the next token
        next_token = mx.random.categorical(probs[None, :])
        next_token = next_token[0]
        mx.eval(next_token)
        token_value = next_token.item()
        generated.append(next_token)
        # Clear intermediate tensors before the next step
        del logits, token_logits, probs
        mx.metal.clear_cache()
        current_prompt = mx.concatenate([current_prompt, next_token[None]])
        if token_value == tokenizer.eos_token_id:
            break
    if not generated:
        return prompt[0]
    result = mx.concatenate([prompt[0], mx.stack(generated)])
    mx.eval(result)
    model.train()
    # Clear generated tokens
    del generated
    mx.metal.clear_cache()
    return result
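For orientation only (not part of the commit): a minimal usage sketch of the sampler above. It assumes a model/tokenizer pair loaded with mlx_lm.load; the model path and prompt are placeholders.

from mlx_lm import load
import mlx.core as mx

# Hypothetical setup; any MLX-converted causal LM works here.
model, tokenizer = load("mlx-community/Qwen2.5-0.5B-Instruct-4bit")
prompt_ids = mx.array(tokenizer.encode("What is 2 + 2?"))
completion_ids = generate_grpo(model, prompt_ids, max_tokens=64, tokenizer=tokenizer, temperature=1.0)
# The returned array is the prompt followed by the sampled tokens.
print(tokenizer.decode(completion_ids.tolist()))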
def r1_extract_xml_answer(text: str) -> str:
@@ -191,67 +173,46 @@ def grpo_loss(
    group_size=4,
    epsilon=1e-4,
    ref_model=None,
    max_tokens=64,
    temperature=1.0
):
"""Modified GRPO loss function with better error handling"""
prompt_tokens, answer_tokens, prompt_text, answer_text = batch
batch_size = len(prompt_tokens)
# Generate completions for each prompt
# Generation logic remains the same
all_completions = []
all_completion_texts = []
for prompt in prompt_tokens:
prompt_tensor = mx.array(prompt)
prompt_completions = []
prompt_completion_texts = []
# Generate group_size completions for each prompt
for _ in range(group_size):
try:
completion_ids = generate_for_grpo(
model,
prompt_tensor,
max_tokens,
tokenizer=tokenizer,
temperature=temperature
)
# Verify completion_ids is not None
completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
if completion_ids is None:
print("Warning: generate_for_grpo returned None")
break
continue
completion_text = tokenizer.decode(completion_ids.tolist())
all_completions.append(completion_ids)
all_completion_texts.append(completion_text)
prompt_completions.append(completion_ids)
prompt_completion_texts.append(completion_text)
del completion_ids
mx.metal.clear_cache()
except Exception as e:
print(f"Error in completion generation: {str(e)}")
# Fallback to using original prompt
prompt_completions.append(prompt_tensor)
prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
print(f"Generation error: {e}")
continue
all_completions.extend(prompt_completions)
all_completion_texts.extend(prompt_completion_texts)
del prompt_tensor
mx.metal.clear_cache()
# Verify we have the expected number of completions
assert len(all_completions) == batch_size * group_size
assert len(all_completion_texts) == batch_size * group_size
    # Prepare inputs
    expanded_answers = []
    expanded_prompts = []
    for i in range(batch_size):
        expanded_answers.extend([answer_text[i]] * group_size)
        expanded_prompts.extend([prompt_text[i]] * group_size)
    # Verify we have the expected number of completions
    assert len(all_completions) == batch_size * group_size
    assert len(all_completion_texts) == batch_size * group_size
    max_length = max(ids.shape[0] for ids in all_completions)
    padded_completions = []
    attention_masks = []
@@ -267,32 +228,37 @@ def grpo_loss(
            mask = mx.ones_like(completion_ids)
        padded_completions.append(padded_ids)
        attention_masks.append(mask)
        del completion_ids
        if padding_length > 0:
            del padding
        del mask
        mx.metal.clear_cache()
    inputs = mx.stack(padded_completions)
    attention_mask = mx.stack(attention_masks)
    lengths = attention_mask.sum(axis=1)
    del padded_completions, attention_masks
    mx.metal.clear_cache()
    # Get logits and compute log probabilities
    logits = model(inputs).astype(mx.float32)
    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
    # Prepare targets
    targets = inputs[:, 1:]
    # Current policy probabilities
    token_log_probs = mx.take_along_axis(
        log_probs,
        targets.reshape(*targets.shape, 1),
        axis=-1
    ).squeeze(-1)
    # Reference policy probabilities
    if ref_model is not None:
        ref_logits = ref_model(inputs).astype(mx.float32)
    else:
        ref_logits = mx.array(logits)
    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
    ref_token_log_probs = mx.take_along_axis(
@@ -301,124 +267,107 @@ def grpo_loss(
        axis=-1
    ).squeeze(-1)
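As an aside, the take_along_axis pattern used for both policies gathers, at each position, the log-probability of the token that actually follows it. A self-contained toy sketch of the same gather (shapes and values invented purely for illustration):

import mlx.core as mx
import mlx.nn as nn

logits = mx.random.normal((1, 4, 5))                    # 1 sequence, 4 positions, vocab of 5
inputs = mx.array([[2, 4, 1, 3]])                       # token ids, shape (1, 4)
log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)  # predictions for the first 3 positions
targets = inputs[:, 1:]                                 # the tokens those positions should predict
token_log_probs = mx.take_along_axis(
    log_probs, targets.reshape(*targets.shape, 1), axis=-1
).squeeze(-1)                                           # shape (1, 3): log p(target | prefix)
print(token_log_probs.shape)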
    # Calculate rewards and advantages
    rewards = mx.zeros((len(all_completions),))
    for reward_func in reward_funcs:
        func_rewards = mx.array(reward_func(
            prompts=prompt_text,
            completions=all_completion_texts,
            answer=answer_text
        ))
        rewards += func_rewards
    # Normalize rewards if using multiple reward functions
    if len(reward_funcs) > 1:
        rewards /= len(reward_funcs)
    # Reshape rewards and compute advantages following GRPO formula
    rewards_reshaped = rewards.reshape(batch_size, group_size)
    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
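A toy check of the group-relative advantage above (illustrative only, not from the commit): each completion is scored only against the other samples drawn for the same prompt.

import mlx.core as mx

# One prompt, group_size = 4, made-up rewards.
rewards = mx.array([1.0, 0.0, 0.5, 0.5])
advantages = (rewards - mx.mean(rewards)) / (mx.std(rewards) + 1e-4)
print(advantages)  # approximately [ 1.41, -1.41, 0.0, 0.0 ]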
    # Compute KL divergence using Schulman's approximator
    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
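The kl_div expression is the exp(Δ) − Δ − 1 estimator with Δ = log π_ref − log π (often called the k3 approximator); it is non-negative elementwise and exactly zero where the two policies agree. A quick numeric sanity check, purely illustrative:

import mlx.core as mx

ref_logp = mx.array([-1.0, -2.0, -0.5])   # made-up reference log-probs
logp = mx.array([-1.2, -1.8, -0.5])       # made-up current-policy log-probs
delta = ref_logp - logp
print(mx.exp(delta) - delta - 1)          # roughly [0.021, 0.019, 0.0], all >= 0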
    # Create mask for valid tokens
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
    # Compute policy ratio
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
    # Compute per-token loss following GRPO formula
    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
    # Average over tokens and sequences
    sequence_sums = (per_token_loss * length_mask).sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)
    loss = (sequence_sums / sequence_lengths).mean()
    # Calculate mean KL divergence for metrics
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
    # Collect reward metrics
    reward_metrics = {}
    for i, reward_func in enumerate(reward_funcs):
        func_rewards = mx.array(reward_func(
            prompts=prompt_text,
            completions=all_completion_texts,
            answer=answer_text
        ))
        reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
        reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)
    # Clean up
    del all_completions
    mx.metal.clear_cache()
    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(rewards_reshaped),
        'grouped_rewards_std': mx.std(rewards_reshaped),
        'kl': mean_kl,
        **reward_metrics
    }
    return loss, sequence_lengths.sum(), metrics
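A hedged sketch of invoking this loss with a single toy reward function. The reward-function keyword signature (prompts=, completions=, answer=, returning one float per completion) is the one the loops above rely on; passing model, tokenizer, and batch as keyword arguments is my assumption, since the start of the grpo_loss signature is not visible in this hunk.

def brevity_reward(prompts, completions, answer):
    # Toy scorer: one value per completion, shorter is better.
    return [1.0 / (1.0 + len(c)) for c in completions]

batch = (prompt_tokens, answer_tokens, prompt_text, answer_text)  # as produced by iterate_grpo_batches
loss, total_tokens, metrics = grpo_loss(
    model=model,
    tokenizer=tokenizer,
    batch=batch,
    reward_funcs=[brevity_reward],
    beta=0.1,
    group_size=4,
    ref_model=None,
)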
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Creates batches from dataset entries for GRPO training.
Args:
dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
tokenizer: Tokenizer for processing inputs
batch_size: Size of each batch
max_seq_length: Maximum sequence length
train: Whether this is for training
Yields:
Tuple containing:
- prompts_tokens: List of token sequences for current batch
- answers_tokens: List of token sequences
- prompts_text: List of prompt strings
- answers_text: List of answer strings
"""
# Verify dataset format
"""Memory-optimized version of iterate_grpo_batches"""
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
# Sort by combined length of prompt + answer tokens
idx = sorted(range(len(dataset)),
key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
# Sort by length but use generator to avoid keeping full sorted list in memory
def length_key(i):
return len(dataset[i][0]) + len(dataset[i][1])
idx = sorted(range(len(dataset)), key=length_key)
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
# Handle distributed training
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
# Create batch indices
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
# Use generator for batch indices
def batch_index_generator():
for i in range(0, len(idx) - batch_size + 1, batch_size):
yield idx[i : i + batch_size : step]
while True:
# Shuffle batch indices if training
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
indices = (
np.random.permutation(list(batch_index_generator())) if train
else batch_index_generator()
)
for i in indices:
# Get current batch
current_batch = [dataset[j] for j in batch_idx[i]]
for batch_idx in indices:
current_batch = [dataset[j] for j in batch_idx]
# Extract all components
prompts_tokens = [item[0] for item in current_batch]
answers_tokens = [item[1] for item in current_batch]
prompts_text = [item[2] for item in current_batch]
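For reference, a small sketch of the entry layout this iterator expects and of pulling one batch from it (names and strings are placeholders; it assumes the function yields the four lists built above):

prompt_str, answer_str = "What is 2 + 2?", "4"
entry = (tokenizer.encode(prompt_str), tokenizer.encode(answer_str), prompt_str, answer_str)
dataset = [entry] * 8  # toy dataset, large enough for batch_size=4

batches = iterate_grpo_batches(dataset, tokenizer, batch_size=4, max_seq_length=512)
prompts_tokens, answers_tokens, prompts_text, answers_text = next(batches)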
@@ -553,7 +502,8 @@ def train_grpo(
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
            max_tokens=args.max_seq_length,
        )
        # All reduce the gradients if running in distributed mode
@@ -649,8 +599,10 @@ def train_grpo(
        losses += loss
        n_tokens += toks
        steps += 1
        for k, v in metrics.items():
            accumulated_metrics[k] += v
        mx.eval(state, losses, n_tokens)
        if it % args.steps_per_report == 0 or it == args.iters:
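The two added lines keep running sums of the per-step metrics so they can be averaged over a reporting window; a tiny self-contained illustration of that pattern (dictionary contents are invented, not the trainer's actual state):

accumulated = {"kl": 0.0, "total_rewards_mean": 0.0}
step_metrics = [{"kl": 0.02, "total_rewards_mean": 0.6},
                {"kl": 0.04, "total_rewards_mean": 0.4}]
for metrics in step_metrics:
    for k, v in metrics.items():
        accumulated[k] += v
averaged = {k: v / len(step_metrics) for k, v in accumulated.items()}
print(averaged)  # roughly {'kl': 0.03, 'total_rewards_mean': 0.5}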