Goekdeniz-Guelmez 2025-02-10 19:45:19 +01:00
parent e5aa2c3b5d
commit b7bc811507

@@ -159,24 +159,19 @@ def get_per_token_logps(model, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
     logits = logits[:, :-1, :]
     targets = inputs[:, 1:]
-    mx.eval(logits)
     per_token_logps = []
     for i in range(logits.shape[0]):
         seq_len = int(lengths[i]) - 1
         seq_logits = logits[i, :seq_len]
         seq_targets = targets[i, :seq_len]
         log_probs = nn.log_softmax(seq_logits, axis=-1)
         token_log_probs = mx.take_along_axis(
             log_probs,
             seq_targets.reshape(seq_len, 1),
             axis=-1
         ).squeeze(-1)
         per_token_logps.append(token_log_probs)
+    mx.eval(logits)
     return per_token_logps
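
As an illustration of the gather pattern in this hunk (not part of the commit): a toy, self-contained run with made-up shapes (batch=1, seq_len=4, vocab=8), showing how `nn.log_softmax` plus `mx.take_along_axis` picks out the log-probability of each realized target token.

```python
import mlx.core as mx
import mlx.nn as nn

logits = mx.random.normal((1, 4, 8)).astype(mx.float16)  # stand-in for model(inputs)
targets = mx.random.randint(0, 8, (1, 4))                # stand-in for inputs[:, 1:]

seq_logits = logits[0]                                   # (4, 8)
seq_targets = targets[0]                                 # (4,)
log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(4, 1), axis=-1
).squeeze(-1)                                            # (4,) one log-prob per target token
print(token_log_probs)
```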
@@ -270,8 +265,8 @@ def grpo_loss(
         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
-    token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
-    ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
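
For context on the `.astype(mx.float32)` removal above (illustration only, not part of the commit): stacking the padded per-sequence log-probs without a cast keeps whatever dtype they already have, which is float16 here since the logits were cast to float16 earlier in the file. A minimal sketch of the pad-and-stack step with made-up values:

```python
import mlx.core as mx

# Two ragged per-sequence log-prob vectors, mirroring padded_log_probs
per_seq = [mx.array([-0.1, -0.2, -0.3], dtype=mx.float16),
           mx.array([-0.5, -0.6], dtype=mx.float16)]
max_len = max(x.shape[0] for x in per_seq)

# Zero-pad each sequence to max_len, matching the padding's dtype to the data
padded = [mx.concatenate([x, mx.zeros(max_len - x.shape[0], dtype=x.dtype)])
          for x in per_seq]
stacked = mx.stack(padded)       # shape (2, 3); dtype stays float16 without an explicit cast
print(stacked.dtype, stacked.shape)
```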
@@ -299,7 +294,7 @@ def grpo_loss(
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
     # Compute policy ratio
-    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))
+    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))
     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
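
To make the formula in this hunk concrete (illustration only, with made-up advantages, KL term, and `beta`): the policy ratio is `exp(logp - stop_gradient(ref_logp))`, and the per-token loss masks out padding positions before any later averaging.

```python
import mlx.core as mx

# Toy batch: 2 sequences x 3 token positions
token_log_probs     = mx.array([[-1.0, -0.5, -0.2], [-0.8, -0.4, -0.1]])
ref_token_log_probs = mx.array([[-1.1, -0.6, -0.3], [-0.7, -0.5, -0.2]])
advantages = mx.array([0.5, -0.5])                 # one scalar advantage per sequence
kl_div = mx.zeros((2, 3))                          # placeholder for the per-token KL term
length_mask = mx.array([[1, 1, 1], [1, 1, 0]])     # second sequence is one token shorter
beta = 0.1

policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
print(per_token_loss)
```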
@@ -580,7 +575,7 @@ def train_grpo(
             for i, reward_func in enumerate(reward_funcs):
                 val_metrics_str += (
                     f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
-                    # f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
+                    f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                 )
             print(