seperate functions

This commit is contained in:
Goekdeniz-Guelmez
2025-03-05 15:28:12 +01:00
parent d723ddfeda
commit f13a0d04ca
2 changed files with 269 additions and 72 deletions

View File

@@ -453,4 +453,177 @@ def main():
if __name__ == "__main__":
main()
main()
def compute_grpo_loss_and_grad(
model,
ref_model,
completion_tensors,
prompt_texts,
answer_texts,
beta=0.1,
epsilon=1e-4,
reward_funcs=None,
reward_weights=None
):
"""
Compute GRPO loss and gradients using pre-generated completions.
Args:
model: The policy model
ref_model: The reference model
completion_tensors: List of tensors containing generated completions
prompt_texts: List of prompt texts
answer_texts: List of answer texts
beta: KL penalty coefficient
epsilon: Numerical stability constant
reward_funcs: List of reward functions
reward_weights: Optional weights for reward functions
"""
# Ensure model is in training mode for gradient computation
model.train()
# Get completion texts for reward calculation
completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]
# Prepare inputs for loss computation
max_length = max(tensor.shape[0] for tensor in completion_tensors)
padded_completions = []
attention_masks = []
for completion_tensor in completion_tensors:
padding_length = max_length - completion_tensor.shape[0]
if padding_length > 0:
padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
padded_ids = mx.concatenate([completion_tensor, padding])
mask = mx.concatenate(
[mx.ones_like(completion_tensor), mx.zeros_like(padding)]
)
else:
padded_ids = completion_tensor
mask = mx.ones_like(completion_tensor)
padded_completions.append(padded_ids)
attention_masks.append(mask)
inputs = mx.stack(padded_completions)
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
# Compute log probabilities for both models
token_log_probs = get_per_token_logps(model, inputs, lengths)
if ref_model is None:
ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]
# Pad log probabilities to same length
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
padded_ref_log_probs = []
for i in range(len(token_log_probs)):
seq_len = token_log_probs[i].shape[0]
padding = mx.zeros((max_len - seq_len,))
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
# Calculate rewards
all_func_rewards = []
for reward_func in reward_funcs:
func_rewards = mx.array(
reward_func(
prompts=prompt_texts,
completions=completion_texts,
answer=answer_texts,
)
)
all_func_rewards.append(func_rewards)
# Stack rewards and apply weights
rewards = mx.stack(all_func_rewards, axis=1)
if reward_weights is not None:
if len(reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
# Group rewards by prompt (assuming completions are grouped by prompt)
group_size = len(completion_tensors) // len(prompt_texts)
if len(completion_tensors) % len(prompt_texts) != 0:
raise ValueError("Number of completions must be divisible by number of prompts")
rewards_by_group = []
for i in range(0, len(rewards), group_size):
rewards_by_group.append(rewards[i:i+group_size])
# Calculate advantages
advantages = mx.zeros_like(rewards)
for i, group_rewards in enumerate(rewards_by_group):
if len(group_rewards) > 1: # Only normalize if we have multiple samples
mean_reward = mx.mean(group_rewards)
std_reward = mx.std(group_rewards)
for j in range(group_size):
idx = i * group_size + j
advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
else:
# If only one sample, advantage is 0
advantages[i * group_size] = 0.0
# Compute KL divergence
kl_div = (
mx.exp(ref_token_log_probs - token_log_probs)
- (ref_token_log_probs - token_log_probs)
- 1
)
# Create mask for valid tokens
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
# Compute policy ratio
policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)
# Compute per-token loss
per_token_loss = -(
(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
)
# Average over tokens
sequence_sums = per_token_loss.sum(axis=1)
sequence_lengths = length_mask.sum(axis=1)
loss = (sequence_sums / sequence_lengths).mean()
# Calculate metrics for reporting
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
metrics = {
"total_rewards_mean": mx.mean(rewards),
"total_rewards_std": mx.std(rewards),
"kl": mean_kl,
}
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
func_rewards = all_func_rewards[i]
metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
metrics[f"{func_name}_std"] = mx.std(func_rewards)
return loss, sequence_lengths.sum(), metrics