removing print and switching some variables in the math

Goekdeniz-Guelmez 2025-02-15 15:38:51 +01:00
parent 5ec4790656
commit 6a6bd53e43


@@ -280,7 +280,6 @@ def grpo_loss(
     # Stack rewards to shape (num_samples, num_funcs)
     rewards = mx.stack(all_func_rewards, axis=1)
-    print(f"Rewards: {rewards}")

     # Apply weights and sum
     if reward_weights is not None:
@@ -293,7 +292,6 @@ def grpo_loss(
     else:
         reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
     rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
-    print(f"Rewards after weights: {rewards}")

     # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
@@ -302,7 +300,7 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

     # Compute KL divergence using Schulman's approximator
-    kl_div = mx.exp(token_log_probs - ref_token_log_probs) - (token_log_probs - ref_token_log_probs) - 1
+    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1

     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
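For reference (not part of the commit), the reordered line matches Schulman's "k3" approximator: for samples drawn from the current policy, exp(ref_logp - logp) - (ref_logp - logp) - 1 is non-negative per token and its expectation equals KL(policy || ref), whereas the previous ordering estimated the reverse direction. Below is a minimal plain-Python sketch checking this numerically; the 3-token policy/reference distributions are made-up illustrative values, and the variable names logp / ref_logp stand in for token_log_probs / ref_token_log_probs from the diff.

    # Sketch: Schulman's k3 estimator, exp(d) - d - 1 with d = ref_logp - logp,
    # averaged over samples from the policy, approximates KL(policy || ref).
    import math
    import random

    random.seed(0)

    # Hypothetical 3-token distributions (assumptions, for illustration only).
    policy = [0.7, 0.2, 0.1]
    ref = [0.5, 0.3, 0.2]

    # Exact KL(policy || ref) for comparison.
    exact_kl = sum(p * math.log(p / q) for p, q in zip(policy, ref))

    # Sample tokens from the policy and accumulate the per-token estimator.
    samples = random.choices(range(3), weights=policy, k=100_000)
    estimates = []
    for tok in samples:
        logp = math.log(policy[tok])       # plays the role of token_log_probs
        ref_logp = math.log(ref[tok])      # plays the role of ref_token_log_probs
        d = ref_logp - logp
        estimates.append(math.exp(d) - d - 1)  # always >= 0, since e^d >= 1 + d

    print(f"exact KL(policy || ref) = {exact_kl:.4f}")
    print(f"k3 estimate             = {sum(estimates) / len(estimates):.4f}")

The two printed values agree closely, which is the sanity check for flipping the sign convention inside the approximator.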