This commit is contained in:
Goekdeniz-Guelmez 2025-02-09 15:41:47 +01:00
parent 39e9469059
commit 54179901b5

View File

@ -183,15 +183,24 @@ def get_per_token_logps(model, inputs, lengths):
return per_token_logps
def compute_kl(logprobs1, logprobs2):
ratio = mx.exp(logprobs1 - logprobs2)
return ratio - 1 - (logprobs1 - logprobs2)
def compute_policy_ratio(current_logprobs, ref_logprobs):
return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
def grpo_loss(
model,
ref_model,
tokenizer,
batch,
reward_funcs=None,
beta=0.1,
group_size=4,
epsilon=1e-4,
ref_model=None,
max_tokens=64,
temperature=1.0
):
@ -257,10 +266,10 @@ def grpo_loss(
mx.metal.clear_cache()
# Reference policy probabilities
if ref_model is not None:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
else:
if ref_model is None:
ref_token_log_probs = token_log_probs
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
@ -268,13 +277,13 @@ def grpo_loss(
for i in range(len(token_log_probs)):
seq_len = token_log_probs[i].shape[0]
padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
padding = mx.zeros((max_len - seq_len,))
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
# Calculate rewards and advantages
rewards = mx.zeros((len(all_completions),))
@ -296,26 +305,23 @@ def grpo_loss(
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
# Compute KL divergence using Schulman's approximator
kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
kl_div = compute_kl(token_log_probs, ref_token_log_probs)
# Create mask for valid tokens
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
# Compute policy ratio
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)
# Compute per-token loss following GRPO formula
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
# Average over tokens and sequences
sequence_sums = per_token_loss.sum(axis=1)
sequence_lengths = length_mask.sum(axis=1)
loss = (sequence_sums / sequence_lengths).mean()
# Average over tokens
loss = per_token_loss.sum().mean()
# Calculate mean KL divergence for metrics
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
# Collect reward metrics
reward_metrics = {}
for i, reward_func in enumerate(reward_funcs):
@ -326,7 +332,6 @@ def grpo_loss(
answer=expanded_answers
))
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
metrics = {
'total_rewards_mean': mx.mean(rewards),
@ -338,11 +343,10 @@ def grpo_loss(
}
mx.metal.clear_cache()
return loss, sequence_lengths.sum(), metrics
return loss, length_mask.sum(axis=1).sum(), metrics
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""Memory-optimized version of iterate_grpo_batches"""
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
@ -585,7 +589,7 @@ def train_grpo(
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
# f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
)
print(