commit 05d921b788
parent 1d9e4802f0
Author: Goekdeniz-Guelmez
Date:   2025-02-03 19:37:05 +01:00


@@ -35,68 +35,50 @@ class GRPOTrainingArgs(TrainingArgs):
 )

-def generate_for_grpo(
-    model,
-    prompt,
-    max_tokens,
-    tokenizer,
-    temperature=1.0
-):
-    try:
-        # Ensure prompt is the right shape
-        if len(prompt.shape) == 1:
-            prompt = prompt[None, :]
-        # Initialize generation
-        generated = []
-        current_prompt = prompt[0]
-        for step in range(max_tokens):
-            try:
-                # Get model output with explicit shape checking
-                current_batch = current_prompt[None, :]
-                logits = model(current_batch)
-                # Ensure we have the last token logits
-                token_logits = logits[0, -1]
-                # Apply temperature and get probabilities
-                if temperature > 0:
-                    token_logits = token_logits / temperature
-                probs = mx.softmax(token_logits)
-                # Sample the next token
-                next_token = mx.random.categorical(probs[None, :])
-                next_token = next_token[0]
-                # Force evaluation to catch any issues
-                mx.eval(next_token)
-                token_value = next_token.item()
-                # Add to generated sequence
-                generated.append(next_token)
-                current_prompt = mx.concatenate([current_prompt, next_token[None]])
-                if token_value == tokenizer.eos_token_id:
-                    break
-            except Exception as e:
-                raise
-        if not generated:
-            return prompt[0]
-        try:
-            result = mx.concatenate([prompt[0], mx.stack(generated)])
-            mx.eval(result)
-            return result
-        except Exception as e:
-            raise
-    except Exception as e:
-        raise
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+    model.eval()
+    if len(prompt.shape) == 1:
+        prompt = prompt[None, :]
+
+    generated = []
+    current_prompt = prompt[0]
+
+    for _ in range(max_tokens):
+        current_batch = current_prompt[None, :]
+        logits = model(current_batch)
+        token_logits = logits[0, -1]
+        if temperature > 0:
+            token_logits = token_logits / temperature
+
+        probs = mx.softmax(token_logits)
+        next_token = mx.random.categorical(probs[None, :])
+        next_token = next_token[0]
+        mx.eval(next_token)
+        token_value = next_token.item()
+        generated.append(next_token)
+
+        # Clear intermediate tensors
+        del logits, token_logits, probs
+        mx.metal.clear_cache()
+
+        current_prompt = mx.concatenate([current_prompt, next_token[None]])
+        if token_value == tokenizer.eos_token_id:
+            break
+
+    if not generated:
+        return prompt[0]
+
+    result = mx.concatenate([prompt[0], mx.stack(generated)])
+    mx.eval(result)
+    model.train()
+
+    # Clear generated tokens
+    del generated
+    mx.metal.clear_cache()
+
+    return result

 def r1_extract_xml_answer(text: str) -> str:
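For orientation, here is a minimal usage sketch of the new generate_grpo helper. It is not part of the diff: the model path, tokenizer, and prompt string are illustrative assumptions (anything loadable with mlx_lm's load() would do); only the call signature above is taken from the code.

# Hypothetical usage of generate_grpo (illustrative, not from this commit).
import mlx.core as mx
from mlx_lm import load  # assumed way of obtaining an MLX model + tokenizer

model, tokenizer = load("some-mlx-model")  # placeholder model path
prompt_ids = mx.array(tokenizer.encode("What is 2 + 2?"))  # 1-D token ids

# Returns prompt + sampled continuation as a single token-id array.
completion_ids = generate_grpo(model, prompt_ids, 64, tokenizer, temperature=1.0)
print(tokenizer.decode(completion_ids.tolist()))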
@@ -191,67 +173,46 @@ def grpo_loss(
     group_size=4,
     epsilon=1e-4,
     ref_model=None,
-    max_tokens=128,
+    max_tokens=64,
     temperature=1.0
 ):
-    """Modified GRPO loss function with better error handling"""
     prompt_tokens, answer_tokens, prompt_text, answer_text = batch
     batch_size = len(prompt_tokens)

-    # Generate completions for each prompt
+    # Generation logic remains the same
     all_completions = []
     all_completion_texts = []

     for prompt in prompt_tokens:
         prompt_tensor = mx.array(prompt)
-        prompt_completions = []
-        prompt_completion_texts = []
-
-        # Generate group_size completions for each prompt
         for _ in range(group_size):
             try:
-                completion_ids = generate_for_grpo(
-                    model,
-                    prompt_tensor,
-                    max_tokens,
-                    tokenizer=tokenizer,
-                    temperature=temperature
-                )
+                completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)

-                # Verify completion_ids is not None
                 if completion_ids is None:
-                    print("Warning: generate_for_grpo returned None")
-                    break
+                    continue

                 completion_text = tokenizer.decode(completion_ids.tolist())
+                all_completions.append(completion_ids)
+                all_completion_texts.append(completion_text)

-                prompt_completions.append(completion_ids)
-                prompt_completion_texts.append(completion_text)
+                del completion_ids
+                mx.metal.clear_cache()
             except Exception as e:
-                print(f"Error in completion generation: {str(e)}")
-                # Fallback to using original prompt
-                prompt_completions.append(prompt_tensor)
-                prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist()))
+                print(f"Generation error: {e}")
+                continue

-        all_completions.extend(prompt_completions)
-        all_completion_texts.extend(prompt_completion_texts)
+        del prompt_tensor
+        mx.metal.clear_cache()

-    # Verify we have the expected number of completions
-    assert len(all_completions) == batch_size * group_size
-    assert len(all_completion_texts) == batch_size * group_size
-
-    # Expand answer_text and prompt_text to match completion groups
+    # Prepare inputs
     expanded_answers = []
     expanded_prompts = []
     for i in range(batch_size):
         expanded_answers.extend([answer_text[i]] * group_size)
         expanded_prompts.extend([prompt_text[i]] * group_size)

-    # Verify we have the expected number of completions
-    assert len(all_completions) == batch_size * group_size
-    assert len(all_completion_texts) == batch_size * group_size
-
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
@@ -268,31 +229,36 @@ def grpo_loss(
         padded_completions.append(padded_ids)
         attention_masks.append(mask)

+        del completion_ids
+        if padding_length > 0:
+            del padding
+        del mask
+        mx.metal.clear_cache()
+
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)

-    # Get logits from current model
+    del padded_completions, attention_masks
+    mx.metal.clear_cache()
+
+    # Get logits and compute log probabilities
     logits = model(inputs).astype(mx.float32)
-
-    # Calculate log probabilities
     log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
-
-    # Prepare targets
     targets = inputs[:, 1:]

-    # Gather actual token probabilities
+    # Current policy probabilities
     token_log_probs = mx.take_along_axis(
         log_probs,
         targets.reshape(*targets.shape, 1),
         axis=-1
     ).squeeze(-1)

-    # Get reference model log probabilities
+    # Reference policy probabilities
     if ref_model is not None:
         ref_logits = ref_model(inputs).astype(mx.float32)
     else:
-        ref_logits = model(inputs).astype(mx.float32)
+        ref_logits = mx.array(logits)

     ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
     ref_token_log_probs = mx.take_along_axis(
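As an aside, the take_along_axis pattern above is what picks out the log-probability of each token that was actually generated. A tiny self-contained sketch of the same pattern, with made-up shapes:

# Standalone illustration of the per-token log-prob gather (shapes are arbitrary).
import mlx.core as mx
import mlx.nn as nn

logits = mx.random.normal((2, 5, 10))        # (batch, seq_len, vocab_size)
inputs = mx.random.randint(0, 10, (2, 5))    # token ids fed to the model

log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)  # predictions for steps 1..T-1
targets = inputs[:, 1:]                                  # tokens that actually came next
token_log_probs = mx.take_along_axis(
    log_probs, targets.reshape(*targets.shape, 1), axis=-1
).squeeze(-1)                                            # shape (batch, seq_len - 1)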
@@ -301,10 +267,7 @@ def grpo_loss(
         axis=-1
     ).squeeze(-1)

-    # Compute KL divergence
-    kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1)
-
-    # Calculate combined rewards from all reward functions
+    # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
     for reward_func in reward_funcs:
         func_rewards = mx.array(reward_func(
@@ -314,36 +277,36 @@ def grpo_loss(
         ))
         rewards += func_rewards

-    # Normalize rewards if using multiple reward functions
     if len(reward_funcs) > 1:
         rewards /= len(reward_funcs)

-    # Compute grouped-wise rewards
-    grouped_rewards = rewards.reshape(batch_size, group_size)
-    mean_grouped_rewards = mx.mean(grouped_rewards, axis=1)
-    std_grouped_rewards = mx.std(grouped_rewards, axis=1)
-
-    # Normalize rewards to compute advantages
-    mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon)
+    # Reshape rewards and compute advantages following GRPO formula
+    rewards_reshaped = rewards.reshape(batch_size, group_size)
+    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
+    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
+    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

-    # Create length mask for the shifted sequence
+    # Compute KL divergence using Schulman's approximator
+    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
+
+    # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

-    # Calculate policy gradient loss
-    per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1)
-    per_token_loss = -(per_token_loss - beta * kl_div)
+    # Compute policy ratio
+    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

-    # Normalize loss properly per sequence
+    # Compute per-token loss following GRPO formula
+    per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
+
+    # Average over tokens and sequences
     sequence_sums = (per_token_loss * length_mask).sum(axis=1)
     sequence_lengths = length_mask.sum(axis=1)
     loss = (sequence_sums / sequence_lengths).mean()

-    # Calculate mean KL divergence
+    # Calculate mean KL divergence for metrics
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()

-    # Collect metrics for each reward function separately
+    # Collect reward metrics
     reward_metrics = {}
     for i, reward_func in enumerate(reward_funcs):
         func_rewards = mx.array(reward_func(
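Written out, the quantities the new right-hand side computes are roughly the following (my notation, not from the commit; sg denotes stop-gradient, G is group_size, r_i the summed reward of completion i, and the per-sequence sums run over valid token positions only):

$$A_i = \frac{r_i - \operatorname{mean}(r_1,\dots,r_G)}{\operatorname{std}(r_1,\dots,r_G) + \epsilon}$$

$$\mathrm{kl}_t = e^{\log \pi_{\mathrm{ref}}(y_t) - \log \pi_\theta(y_t)} - \big(\log \pi_{\mathrm{ref}}(y_t) - \log \pi_\theta(y_t)\big) - 1$$

$$\mathcal{L} = \frac{1}{N}\sum_{i=1}^{N} \frac{1}{T_i} \sum_{t=1}^{T_i} -\Big(\frac{\pi_\theta(y_t)}{\operatorname{sg}[\pi_\theta(y_t)]}\, A_i - \beta\, \mathrm{kl}_t\Big)$$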
@@ -351,15 +314,18 @@ def grpo_loss(
             completions=all_completion_texts,
             answer=answer_text
         ))
-        # func_grouped_rewards = func_rewards.reshape(batch_size, group_size)
         reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)

+    # Clean up
+    del all_completions
+    mx.metal.clear_cache()
+
     metrics = {
         'total_rewards_mean': mx.mean(rewards),
         'total_rewards_std': mx.std(rewards),
-        'grouped_rewards_mean': mx.mean(grouped_rewards),
-        'grouped_rewards_std': mx.std(grouped_rewards),
+        'grouped_rewards_mean': mx.mean(rewards_reshaped),
+        'grouped_rewards_std': mx.std(rewards_reshaped),
         'kl': mean_kl,
         **reward_metrics
     }
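For context, each reward_func is called with keyword arguments (completions and answer are visible in the hunk above) and is expected to return one numeric score per completion. A hypothetical reward function in that shape, reusing the file's r1_extract_xml_answer helper; the function name and the assumption that answer is aligned one-to-one with completions are mine:

# Hypothetical reward function (signature inferred from the call sites above;
# any extra keyword arguments such as prompts are absorbed by **kwargs).
def exact_match_reward(completions=None, answer=None, **kwargs):
    # 1.0 when the extracted <answer> matches the reference, else 0.0;
    # assumes one reference answer per completion.
    return [
        1.0 if r1_extract_xml_answer(c) == a else 0.0
        for c, a in zip(completions, answer)
    ]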
@@ -368,30 +334,15 @@ def grpo_loss(
 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Creates batches from dataset entries for GRPO training.
-
-    Args:
-        dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples
-        tokenizer: Tokenizer for processing inputs
-        batch_size: Size of each batch
-        max_seq_length: Maximum sequence length
-        train: Whether this is for training
-
-    Yields:
-        Tuple containing:
-            - prompts_tokens: List of token sequences for current batch
-            - answers_tokens: List of token sequences
-            - prompts_text: List of prompt strings
-            - answers_text: List of answer strings
-    """
-    # Verify dataset format
+    """Memory-optimized version of iterate_grpo_batches"""
     if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
         raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

-    # Sort by combined length of prompt + answer tokens
-    idx = sorted(range(len(dataset)),
-                 key=lambda i: len(dataset[i][0]) + len(dataset[i][1]))
+    # Sort by length but use generator to avoid keeping full sorted list in memory
+    def length_key(i):
+        return len(dataset[i][0]) + len(dataset[i][1])
+
+    idx = sorted(range(len(dataset)), key=length_key)

     if len(dataset) < batch_size:
         raise ValueError(
@@ -399,26 +350,24 @@ def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
             f"examples but only has {len(dataset)}."
         )

-    # Handle distributed training
     step = mx.distributed.init().size()
     if batch_size % step != 0:
         raise ValueError("The batch size must be divisible by the number of workers")

-    # Create batch indices
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
+    # Use generator for batch indices
+    def batch_index_generator():
+        for i in range(0, len(idx) - batch_size + 1, batch_size):
+            yield idx[i : i + batch_size : step]

     while True:
-        # Shuffle batch indices if training
-        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
+        indices = (
+            np.random.permutation(list(batch_index_generator())) if train
+            else batch_index_generator()
+        )

-        for i in indices:
-            # Get current batch
-            current_batch = [dataset[j] for j in batch_idx[i]]
-
-            # Extract all components
+        for batch_idx in indices:
+            current_batch = [dataset[j] for j in batch_idx]
             prompts_tokens = [item[0] for item in current_batch]
             answers_tokens = [item[1] for item in current_batch]
             prompts_text = [item[2] for item in current_batch]
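A short sketch of how this iterator would typically be consumed; the dataset, tokenizer, and sizes are illustrative assumptions, and only the yielded tuple layout is taken from the code above:

# Hypothetical consumption of iterate_grpo_batches (illustrative only).
# Each dataset entry: (prompt_tokens, answer_tokens, prompt_str, answer_str).
batch_iter = iterate_grpo_batches(dataset, tokenizer, batch_size=4, max_seq_length=512, train=True)

for prompts_tokens, answers_tokens, prompts_text, answers_text in batch_iter:
    # Each element is a list covering this worker's slice of the batch.
    print(len(prompts_tokens), prompts_text[0])
    break  # the iterator loops forever (while True), so stop after one batch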
@@ -553,7 +502,8 @@ def train_grpo(
             beta=args.beta,
             group_size=args.group_size,
             epsilon=args.epsilon,
-            ref_model=ref_model
+            ref_model=ref_model,
+            max_tokens=args.max_seq_length,
         )

         # All reduce the gradients if running in distributed mode
@@ -649,8 +599,10 @@ def train_grpo(
         losses += loss
         n_tokens += toks
         steps += 1
+
         for k, v in metrics.items():
             accumulated_metrics[k] += v
+
         mx.eval(state, losses, n_tokens)

         if it % args.steps_per_report == 0 or it == args.iters: