mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-28 03:41:17 +08:00

commit 54179901b5 (parent 39e9469059): fix
@@ -183,15 +183,24 @@ def get_per_token_logps(model, inputs, lengths):
     return per_token_logps


+def compute_kl(logprobs1, logprobs2):
+    ratio = mx.exp(logprobs1 - logprobs2)
+    return ratio - 1 - (logprobs1 - logprobs2)
+
+
+def compute_policy_ratio(current_logprobs, ref_logprobs):
+    return mx.exp(mx.array(current_logprobs - ref_logprobs, dtype=mx.float32))
+
+
 def grpo_loss(
     model,
+    ref_model,
     tokenizer,
     batch,
     reward_funcs=None,
     beta=0.1,
     group_size=4,
     epsilon=1e-4,
-    ref_model=None,
     max_tokens=64,
     temperature=1.0
 ):
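
For context, the compute_kl helper added above is an elementwise term of the form exp(d) - 1 - d with d = logprobs1 - logprobs2 (the "Schulman's approximator" named in a later hunk); it is nonnegative and zero when the two log-probabilities agree. A minimal standalone sketch with made-up values, not taken from the repository:

    import mlx.core as mx

    # Hypothetical per-token log-probabilities for current and reference policies
    cur_logps = mx.array([-0.5, -1.2, -0.1])
    ref_logps = mx.array([-0.7, -1.0, -0.3])

    d = cur_logps - ref_logps
    kl = mx.exp(d) - 1 - d    # same expression as compute_kl(cur_logps, ref_logps)
    ratio = mx.exp(d)         # same as compute_policy_ratio, minus the float32 cast
    print(kl, ratio)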
@@ -257,10 +266,10 @@ def grpo_loss(
     mx.metal.clear_cache()

     # Reference policy probabilities
-    if ref_model is not None:
-        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
-    else:
+    if ref_model is None:
         ref_token_log_probs = token_log_probs
+    else:
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)

     max_len = max(x.shape[0] for x in token_log_probs)
     padded_log_probs = []
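
A small, self-contained illustration of the fallback branch above: when ref_model is None the reference log-probabilities are simply the current ones, so the KL term computed later is identically zero (toy values assumed for illustration):

    import mlx.core as mx

    token_log_probs = mx.array([-0.4, -2.0, -0.9])
    ref_token_log_probs = token_log_probs  # ref_model is None: reuse current log-probs

    d = token_log_probs - ref_token_log_probs
    print(mx.exp(d) - 1 - d)  # [0, 0, 0] -> the beta * kl_div penalty vanishes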
@@ -268,13 +277,13 @@

     for i in range(len(token_log_probs)):
         seq_len = token_log_probs[i].shape[0]
-        padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
+        padding = mx.zeros((max_len - seq_len,))

         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

-    token_log_probs = mx.stack(padded_log_probs)
-    ref_token_log_probs = mx.stack(padded_ref_log_probs)
+    token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)

     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
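
The loop above pads the ragged per-sequence log-prob vectors to a common length before stacking; a toy sketch of the same pattern (lengths and values invented for illustration):

    import mlx.core as mx

    # Two sequences of different lengths
    token_log_probs = [mx.array([-0.1, -0.5, -0.2]), mx.array([-0.3, -0.4])]
    max_len = max(x.shape[0] for x in token_log_probs)

    padded = []
    for x in token_log_probs:
        padding = mx.zeros((max_len - x.shape[0],))
        padded.append(mx.concatenate([x, padding]))

    stacked = mx.stack(padded).astype(mx.float32)
    print(stacked.shape)  # (2, 3); padded positions are excluded later via length_mask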
@@ -296,22 +305,19 @@
     advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

     # Compute KL divergence using Schulman's approximator
-    kl_div = (mx.exp(token_log_probs - ref_token_log_probs) - 1) - (token_log_probs - ref_token_log_probs)
+    kl_div = compute_kl(token_log_probs, ref_token_log_probs)

     # Create mask for valid tokens
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

     # Compute policy ratio
-    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
+    policy_ratio = compute_policy_ratio(token_log_probs, ref_token_log_probs)

     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

-    # Average over tokens and sequences
-    sequence_sums = per_token_loss.sum(axis=1)
-    sequence_lengths = length_mask.sum(axis=1)
-
-    loss = (sequence_sums / sequence_lengths).mean()
+    # Average over tokens
+    loss = per_token_loss.sum().mean()

     # Calculate mean KL divergence for metrics
     mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
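
To make the per-token objective above concrete, here is a toy evaluation of the same formula with invented shapes and values (2 sequences, 3 token positions; beta and masks are placeholders):

    import mlx.core as mx

    policy_ratio = mx.ones((2, 3))             # exp of current-vs-reference log-prob gap
    kl_div = mx.zeros((2, 3))                  # per-token KL penalty
    advantages = mx.array([1.0, -0.5])         # one normalized advantage per sequence
    length_mask = mx.array([[1, 1, 1], [1, 1, 0]], dtype=mx.float32)
    beta = 0.1

    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
    loss = per_token_loss.sum().mean()         # aggregation used in this revision
    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
    print(loss, mean_kl)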
@@ -326,7 +332,6 @@ def grpo_loss(
             answer=expanded_answers
         ))
         reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
-        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)

     metrics = {
         'total_rewards_mean': mx.mean(rewards),
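
The hunk above only shows the answer=expanded_answers argument and the mean aggregation; the sketch below assumes, hypothetically, that a reward function takes prompts, completions, and answer keyword arguments and returns one score per completion:

    import mlx.core as mx

    def dummy_reward(prompts=None, completions=None, answer=None):
        # Hypothetical reward: 1.0 when the expected answer appears in the completion
        return [1.0 if a in c else 0.0 for c, a in zip(completions, answer)]

    func_rewards = mx.array(dummy_reward(
        prompts=["What is 2+2?", "Capital of France?"],
        completions=["The answer is 4", "It might be Berlin"],
        answer=["4", "Paris"],
    ))
    print(mx.mean(func_rewards))  # aggregated like reward_metrics[f'{func_name}_mean']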
@@ -338,11 +343,10 @@ def grpo_loss(
     }
     mx.metal.clear_cache()

-    return loss, sequence_lengths.sum(), metrics
+    return loss, length_mask.sum(axis=1).sum(), metrics


 def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """Memory-optimized version of iterate_grpo_batches"""
     if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
         raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
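
A small sketch of the dataset layout that the check above enforces, with invented token IDs and strings:

    dataset = [
        ([101, 2054, 2003], [102, 2176], "What is 2+2?", "4"),
        ([101, 3007, 1997], [102, 3000], "Capital of France?", "Paris"),
    ]

    # Mirrors the validation in iterate_grpo_batches
    assert dataset and isinstance(dataset[0], tuple) and len(dataset[0]) == 4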
@@ -585,7 +589,7 @@ def train_grpo(
            for i, reward_func in enumerate(reward_funcs):
                val_metrics_str += (
                    f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
-                    f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
+                    # f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                )

            print(