removing helper functions

commit f88e897019
parent d9da35f458
@@ -180,15 +180,6 @@ def get_per_token_logps(model, inputs, lengths):
     return per_token_logps
 
 
-def compute_kl(logprobs1, logprobs2):
-    ratio = mx.exp(logprobs1 - logprobs2)
-    return ratio - 1 - (logprobs1 - logprobs2)
-
-
-def compute_policy_ratio(current_logprobs, ref_logprobs):
-    return mx.exp(mx.array(current_logprobs - mx.stop_gradient(ref_logprobs), dtype=mx.float32))
-
-
 def grpo_loss(
     model,
     ref_model,
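For reference, the deleted compute_kl helper implements Schulman's k3 KL approximator: k3(delta) = exp(delta) - delta - 1 with delta = log pi - log pi_ref, which is non-negative everywhere and zero exactly when the two log-probabilities agree. A minimal standalone sketch of that property (the toy log-probability values are made up for illustration):

    import mlx.core as mx

    # Hypothetical per-token log-probs under the current and reference policies.
    token_log_probs = mx.array([-1.20, -0.50, -2.30])
    ref_token_log_probs = mx.array([-1.00, -0.55, -2.00])

    # Schulman's k3 estimator, as this commit inlines it:
    #   k3 = exp(delta) - delta - 1,  delta = log pi - log pi_ref
    delta = token_log_probs - ref_token_log_probs
    kl_div = mx.exp(delta) - delta - 1

    print(kl_div)                # every entry is >= 0
    print(mx.all(kl_div >= 0))   # k3 is non-negative; 0 iff delta == 0

The second hunk inlines this expression, and the policy ratio, at the call sites inside grpo_loss: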
@@ -302,13 +293,13 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
 
     # Compute KL divergence using Schulman's approximator
-    kl_div = compute_kl(token_log_probs, ref_token_log_probs)
+    kl_div = mx.exp(token_log_probs - ref_token_log_probs) - (token_log_probs - ref_token_log_probs) - 1
 
     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
 
     # Compute policy ratio
-    policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)
+    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))
 
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
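To see how the inlined pieces fit together, here is a self-contained sketch of the per-token loss computation. The shapes, toy values, and the final mean reduction are illustrative assumptions, not the repository's exact grpo_loss:

    import mlx.core as mx

    # Hypothetical toy batch: 2 completions, 5 input tokens each,
    # hence 4 next-token prediction positions per sequence.
    token_log_probs = mx.zeros((2, 4)) - 1.0      # current policy log-probs
    ref_token_log_probs = mx.zeros((2, 4)) - 1.1  # reference policy log-probs
    lengths = mx.array([5, 3])                    # valid tokens per sequence
    advantages = mx.array([0.7, -0.7])            # (rewards - mean) / (std + epsilon)
    beta = 0.1                                    # KL penalty coefficient (assumed value)

    # KL approximator and policy ratio, inlined as in this commit
    delta = token_log_probs - ref_token_log_probs
    kl_div = mx.exp(delta) - delta - 1
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))

    # Mask out positions beyond each sequence's actual length
    num_positions = token_log_probs.shape[1]
    length_mask = mx.arange(num_positions)[None, :] < (lengths[:, None] - 1)

    # Per-token GRPO objective: reward-weighted ratio minus KL penalty
    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

    # Reduce to a scalar by averaging over valid tokens only (one common choice)
    loss = (per_token_loss.sum() / length_mask.sum()).item()
    print(loss)

Note the role of mx.stop_gradient in the policy ratio: it keeps the reference log-probs out of the backward pass, so gradients flow only through the current policy's log-probabilities.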