mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
removing print and switching some variables in the math
This commit is contained in:
parent 5ec4790656
commit 6a6bd53e43
@@ -280,7 +280,6 @@ def grpo_loss(
 
     # Stack rewards to shape (num_samples, num_funcs)
    rewards = mx.stack(all_func_rewards, axis=1)
-    print(f"Rewards: {rewards}")
 
     # Apply weights and sum
     if reward_weights is not None:
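For context on the first hunk (not part of the diff): mx.stack with axis=1 turns the per-function reward lists into a single matrix with one row per sampled completion. A minimal sketch, with made-up reward values and function roles:

import mlx.core as mx

# Two hypothetical reward functions scored over four sampled completions;
# in grpo_loss, all_func_rewards holds one 1-D array per reward function.
all_func_rewards = [
    mx.array([1.0, 0.0, 1.0, 0.5]),  # e.g. a correctness reward
    mx.array([0.2, 0.9, 0.4, 0.1]),  # e.g. a formatting reward
]

# Stacking along axis=1 yields shape (num_samples, num_funcs) = (4, 2):
# each row collects every reward signal for one completion.
rewards = mx.stack(all_func_rewards, axis=1)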
@@ -293,7 +292,6 @@ def grpo_loss(
     else:
         reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
     rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
-    print(f"Rewards after weights: {rewards}")
 
     # Reshape rewards and compute advantages
     rewards_reshaped = rewards.reshape(batch_size, group_size)
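The surviving lines of the second hunk apply the per-function weights and collapse each row to one scalar reward per completion, which the following lines (not shown in this diff) normalize per group. A rough sketch of the whole step; the mean/std lines are a reconstruction of code the diff skips, and all values are invented:

import mlx.core as mx

batch_size, group_size, num_funcs = 2, 2, 2
epsilon = 1e-4  # assumed small constant; the real value comes from the caller

rewards = mx.array([[1.0, 0.2],
                    [0.0, 0.9],
                    [1.0, 0.4],
                    [0.5, 0.1]])  # shape (num_samples, num_funcs)
reward_weights = mx.ones(num_funcs, dtype=mx.float32)

# Broadcast the per-function weights over samples, then sum each row
# down to one scalar reward per completion: shape (num_samples,)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)

# Group-relative advantages: each completion is scored against the
# mean/std of its own group (reconstruction of lines the diff omits).
rewards_reshaped = rewards.reshape(batch_size, group_size)
mean_rewards = mx.repeat(mx.mean(rewards_reshaped, axis=1), group_size)
std_rewards = mx.repeat(mx.std(rewards_reshaped, axis=1), group_size)
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)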
@@ -302,7 +300,7 @@ def grpo_loss(
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
 
     # Compute KL divergence using Schulman's approximator
-    kl_div = mx.exp(token_log_probs - ref_token_log_probs) - (token_log_probs - ref_token_log_probs) - 1
+    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
 
     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
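The third hunk is the actual math fix. GRPO samples completions from the current policy, so Schulman's k3 estimator of KL(pi_theta || pi_ref) needs the ratio r = pi_ref / pi_theta, i.e. the log-ratio ref_token_log_probs - token_log_probs; the removed line had the ratio the other way around. A small sketch with toy log-probs (values invented for illustration):

import mlx.core as mx

token_log_probs = mx.array([-1.2, -0.4, -2.0])      # log pi_theta per token
ref_token_log_probs = mx.array([-1.0, -0.7, -1.5])  # log pi_ref per token

# k3 estimator: KL(pi_theta || pi_ref) ~= E_{x ~ pi_theta}[r - log(r) - 1]
# with r = pi_ref(x) / pi_theta(x). Each per-token term is >= 0, so the
# estimate can never go negative.
log_ratio = ref_token_log_probs - token_log_probs
kl_div = mx.exp(log_ratio) - log_ratio - 1

# The removed line flipped the log-ratio (r = pi_theta / pi_ref), which is
# the estimator for samples drawn from pi_ref -- not what GRPO generates.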