mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-12-16 02:08:55 +08:00
separate functions
@@ -453,4 +453,177 @@ def main():


if __name__ == "__main__":
    main()


def compute_grpo_loss_and_grad(
    model,
    ref_model,
    tokenizer,
    completion_tensors,
    prompt_texts,
    answer_texts,
    beta=0.1,
    epsilon=1e-4,
    reward_funcs=None,
    reward_weights=None,
):
    """
    Compute GRPO loss and gradients using pre-generated completions.

    Args:
        model: The policy model
        ref_model: The reference model
        tokenizer: Tokenizer used to decode completions for reward calculation
        completion_tensors: List of tensors containing generated completions
        prompt_texts: List of prompt texts
        answer_texts: List of answer texts
        beta: KL penalty coefficient
        epsilon: Numerical stability constant
        reward_funcs: List of reward functions
        reward_weights: Optional weights for reward functions
    """
    # Ensure model is in training mode for gradient computation
    model.train()

    # Get completion texts for reward calculation
    completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]

    # Prepare inputs for loss computation
    max_length = max(tensor.shape[0] for tensor in completion_tensors)
    padded_completions = []
    attention_masks = []

    for completion_tensor in completion_tensors:
        padding_length = max_length - completion_tensor.shape[0]
        if padding_length > 0:
            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
            padded_ids = mx.concatenate([completion_tensor, padding])
            mask = mx.concatenate(
                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
            )
        else:
            padded_ids = completion_tensor
            mask = mx.ones_like(completion_tensor)
        padded_completions.append(padded_ids)
        attention_masks.append(mask)

    inputs = mx.stack(padded_completions)
    attention_mask = mx.stack(attention_masks)
    lengths = attention_mask.sum(axis=1)

    # Compute log probabilities for both models
    token_log_probs = get_per_token_logps(model, inputs, lengths)

    if ref_model is None:
        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
    else:
        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]

    # Pad log probabilities to same length
    max_len = max(x.shape[0] for x in token_log_probs)
    padded_log_probs = []
    padded_ref_log_probs = []

    for i in range(len(token_log_probs)):
        seq_len = token_log_probs[i].shape[0]
        padding = mx.zeros((max_len - seq_len,))

        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

    token_log_probs = mx.stack(padded_log_probs)
    ref_token_log_probs = mx.stack(padded_ref_log_probs)

    # Calculate rewards
    all_func_rewards = []
    for reward_func in reward_funcs:
        func_rewards = mx.array(
            reward_func(
                prompts=prompt_texts,
                completions=completion_texts,
                answer=answer_texts,
            )
        )
        all_func_rewards.append(func_rewards)

    # Stack rewards and apply weights
    rewards = mx.stack(all_func_rewards, axis=1)
    if reward_weights is not None:
        if len(reward_weights) != len(reward_funcs):
            raise ValueError(
                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
                f"functions ({len(reward_funcs)})"
            )
        reward_weights = mx.array(reward_weights, dtype=mx.float32)
    else:
        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)

    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)

    # Group rewards by prompt (assuming completions are grouped by prompt)
    group_size = len(completion_tensors) // len(prompt_texts)
    if len(completion_tensors) % len(prompt_texts) != 0:
        raise ValueError("Number of completions must be divisible by number of prompts")

    rewards_by_group = []
    for i in range(0, len(rewards), group_size):
        rewards_by_group.append(rewards[i : i + group_size])

    # Calculate advantages
    advantages = mx.zeros_like(rewards)
    for i, group_rewards in enumerate(rewards_by_group):
        if len(group_rewards) > 1:  # Only normalize if we have multiple samples
            mean_reward = mx.mean(group_rewards)
            std_reward = mx.std(group_rewards)

            for j in range(group_size):
                idx = i * group_size + j
                advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
        else:
            # If only one sample, advantage is 0
            advantages[i * group_size] = 0.0

    # Compute KL divergence
    kl_div = (
        mx.exp(ref_token_log_probs - token_log_probs)
        - (ref_token_log_probs - token_log_probs)
        - 1
    )

    # Create mask for valid tokens
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # Compute policy ratio
    policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)

    # Compute per-token loss
    per_token_loss = -(
        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
    )

    # Average over tokens
    sequence_sums = per_token_loss.sum(axis=1)
    sequence_lengths = length_mask.sum(axis=1)
    loss = (sequence_sums / sequence_lengths).mean()

    # Calculate metrics for reporting
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()

    metrics = {
        "total_rewards_mean": mx.mean(rewards),
        "total_rewards_std": mx.std(rewards),
        "kl": mean_kl,
    }

    for i, reward_func in enumerate(reward_funcs):
        func_name = reward_func.__name__
        func_rewards = all_func_rewards[i]
        metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
        metrics[f"{func_name}_std"] = mx.std(func_rewards)

    return loss, sequence_lengths.sum(), metrics
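
A minimal sketch of how this helper could be driven from a training step. This is not part of the commit: it assumes `model`, `ref_model`, `tokenizer`, `completion_tensors`, `prompt_texts`, and `answer_texts` are already prepared by the surrounding script, and it uses a hypothetical reward function name (`accuracy_reward`).

import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

optimizer = optim.Adam(learning_rate=1e-5)

def loss_fn(model):
    # Gradients flow only through the policy model's parameters.
    return compute_grpo_loss_and_grad(
        model,
        ref_model,
        tokenizer,
        completion_tensors,
        prompt_texts,
        answer_texts,
        beta=0.1,
        epsilon=1e-4,
        reward_funcs=[accuracy_reward],  # hypothetical reward function
    )

# nn.value_and_grad differentiates the first returned value (the scalar loss)
# and passes the remaining outputs (token count, metrics) through unchanged.
loss_value_and_grad = nn.value_and_grad(model, loss_fn)
(loss, num_tokens, metrics), grads = loss_value_and_grad(model)
optimizer.update(model, grads)
mx.eval(model.parameters(), optimizer.state, loss)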
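
For intuition on the group-relative advantage step above, a toy check with made-up numbers (illustrative only, not from the commit):

import mlx.core as mx

# One prompt, group_size = 3, made-up rewards.
group_rewards = mx.array([1.0, 0.0, 0.5])
mean_reward = mx.mean(group_rewards)   # 0.5
std_reward = mx.std(group_rewards)     # ~0.408 (population std)
advantages = (group_rewards - mean_reward) / (std_reward + 1e-4)
# advantages ~= [ 1.22, -1.22, 0.0 ]: completions above their group's mean are
# pushed up, those below are pushed down, relative to their own prompt group.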