mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-27 03:05:20 +08:00
fix
This commit is contained in:
parent
39e9469059
commit
54179901b5
@ -183,15 +183,24 @@ def get_per_token_logps(model, inputs, lengths):
|
||||
return per_token_logps
|
||||
|
||||
|
||||
def compute_kl(logprobs1, logprobs2):
|
||||
ratio = mx.exp(logprobs1 - logprobs2)
|
||||
return ratio - 1 - (logprobs1 - logprobs2)
|
||||
|
||||
|
||||
def compute_policy_ratio(current_logprobs, ref_logprobs):
|
||||
return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
|
||||
|
||||
|
||||
def grpo_loss(
|
||||
model,
|
||||
ref_model,
|
||||
tokenizer,
|
||||
batch,
|
||||
reward_funcs=None,
|
||||
beta=0.1,
|
||||
group_size=4,
|
||||
epsilon=1e-4,
|
||||
ref_model=None,
|
||||
max_tokens=64,
|
||||
temperature=1.0
|
||||
):
|
||||
@ -257,10 +266,10 @@ def grpo_loss(
|
||||
mx.metal.clear_cache()
|
||||
|
||||
# Reference policy probabilities
|
||||
if ref_model is not None:
|
||||
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
|
||||
else:
|
||||
if ref_model is None:
|
||||
ref_token_log_probs = token_log_probs
|
||||
else:
|
||||
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
|
||||
|
||||
max_len = max(x.shape[0] for x in token_log_probs)
|
||||
padded_log_probs = []
|
||||
@ -268,13 +277,13 @@ def grpo_loss(
|
||||
|
||||
for i in range(len(token_log_probs)):
|
||||
seq_len = token_log_probs[i].shape[0]
|
||||
padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
|
||||
padding = mx.zeros((max_len - seq_len,))
|
||||
|
||||
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
|
||||
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
|
||||
|
||||
token_log_probs = mx.stack(padded_log_probs)
|
||||
ref_token_log_probs = mx.stack(padded_ref_log_probs)
|
||||
token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
|
||||
ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
|
||||
|
||||
# Calculate rewards and advantages
|
||||
rewards = mx.zeros((len(all_completions),))
|
||||
@ -296,26 +305,23 @@ def grpo_loss(
|
||||
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
|
||||
|
||||
# Compute KL divergence using Schulman's approximator
|
||||
kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
|
||||
kl_div = compute_kl(token_log_probs, ref_token_log_probs)
|
||||
|
||||
# Create mask for valid tokens
|
||||
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
|
||||
|
||||
# Compute policy ratio
|
||||
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
|
||||
policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)
|
||||
|
||||
# Compute per-token loss following GRPO formula
|
||||
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
|
||||
|
||||
# Average over tokens and sequences
|
||||
sequence_sums = per_token_loss.sum(axis=1)
|
||||
sequence_lengths = length_mask.sum(axis=1)
|
||||
|
||||
loss = (sequence_sums / sequence_lengths).mean()
|
||||
# Average over tokens
|
||||
loss = per_token_loss.sum().mean()
|
||||
|
||||
# Calculate mean KL divergence for metrics
|
||||
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
|
||||
|
||||
|
||||
# Collect reward metrics
|
||||
reward_metrics = {}
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
@ -326,7 +332,6 @@ def grpo_loss(
|
||||
answer=expanded_answers
|
||||
))
|
||||
reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
|
||||
reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
|
||||
|
||||
metrics = {
|
||||
'total_rewards_mean': mx.mean(rewards),
|
||||
@ -338,11 +343,10 @@ def grpo_loss(
|
||||
}
|
||||
mx.metal.clear_cache()
|
||||
|
||||
return loss, sequence_lengths.sum(), metrics
|
||||
return loss, length_mask.sum(axis=1).sum(), metrics
|
||||
|
||||
|
||||
def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
|
||||
"""Memory-optimized version of iterate_grpo_batches"""
|
||||
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
|
||||
raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
|
||||
|
||||
@ -585,7 +589,7 @@ def train_grpo(
|
||||
for i, reward_func in enumerate(reward_funcs):
|
||||
val_metrics_str += (
|
||||
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
|
||||
f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
|
||||
# f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
|
||||
)
|
||||
|
||||
print(
|
||||
|
Loading…
Reference in New Issue
Block a user