removing helper functions

Goekdeniz-Guelmez 2025-02-10 16:07:28 +01:00
parent d9da35f458
commit f88e897019


@@ -180,15 +180,6 @@ def get_per_token_logps(model, inputs, lengths):
     return per_token_logps


-def compute_kl(logprobs1, logprobs2):
-    ratio = mx.exp(logprobs1 - logprobs2)
-    return ratio - 1 - (logprobs1 - logprobs2)
-
-
-def compute_policy_ratio(current_logprobs, ref_logprobs):
-    return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
-
-
 def grpo_loss(
     model,
     ref_model,
@@ -302,13 +293,13 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

     # Compute KL divergence using Schulman's approximator
-    kl_div = compute_kl(token_log_probs, ref_token_log_probs)
+    kl_div = mx.exp(token_log_probs - ref_token_log_probs) - (token_log_probs - ref_token_log_probs) - 1

     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

     # Compute policy ratio
-    policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)
+    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))

     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
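
For readers comparing the two versions, here is a minimal, self-contained sketch (not part of this commit) of what the inlined expressions compute. The toy log-probabilities, the advantage value, and beta below are hypothetical placeholders:

# Minimal sketch of the inlined GRPO loss terms; all inputs are
# hypothetical toy values, not the project's actual data.
import mlx.core as mx

token_log_probs = mx.array([[-1.2, -0.8, -2.0]])      # current policy log-probs
ref_token_log_probs = mx.array([[-1.0, -1.0, -1.5]])  # reference policy log-probs

# Schulman's approximator: exp(x) - x - 1 is non-negative and
# zero exactly when the two policies assign equal probability.
diff = token_log_probs - ref_token_log_probs
kl_div = mx.exp(diff) - diff - 1

# Policy ratio pi_theta / pi_ref; stop_gradient keeps the reference
# log-probs out of the backward pass.
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))

advantages = mx.array([0.5])  # one advantage per sequence (hypothetical)
beta = 0.04                   # KL penalty weight (hypothetical)

per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)
print(per_token_loss)  # shape (1, 3): one loss value per token

The exp(x) - x - 1 form is a low-variance KL estimator: it stays non-negative and vanishes only when the current and reference policies agree, while stop_gradient ensures the reference policy contributes no gradient.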