Mirror of https://github.com/ml-explore/mlx-examples.git, synced 2025-06-28 03:41:17 +08:00
nits
This commit is contained in:
parent e5aa2c3b5d
commit b7bc811507
@@ -159,24 +159,19 @@ def get_per_token_logps(model, inputs, lengths):
    logits = model(inputs).astype(mx.float16)
    logits = logits[:, :-1, :]
    targets = inputs[:, 1:]

    mx.eval(logits)
    per_token_logps = []
    for i in range(logits.shape[0]):
        seq_len = int(lengths[i]) - 1

        seq_logits = logits[i, :seq_len]
        seq_targets = targets[i, :seq_len]

        log_probs = nn.log_softmax(seq_logits, axis=-1)

        token_log_probs = mx.take_along_axis(
            log_probs,
            seq_targets.reshape(seq_len, 1),
            axis=-1
        ).squeeze(-1)

        per_token_logps.append(token_log_probs)
    mx.eval(logits)
    return per_token_logps
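For readers skimming the hunk above, the core of get_per_token_logps is the gather: for every position in a sequence, pick out the log-probability the model assigned to the token that actually comes next. A minimal, self-contained sketch of that mechanic; the shapes and token ids below are invented for illustration.

import mlx.core as mx
import mlx.nn as nn

# Toy logits for one 4-token sequence over a 6-token vocabulary.
seq_logits = mx.random.normal((4, 6))
# Next-token targets for those positions (targets = inputs[:, 1:] in the trainer).
seq_targets = mx.array([2, 5, 0, 3])

log_probs = nn.log_softmax(seq_logits, axis=-1)  # (4, 6) normalized log-probabilities
token_log_probs = mx.take_along_axis(
    log_probs, seq_targets.reshape(4, 1), axis=-1
).squeeze(-1)  # (4,): one log-prob per target token
print(token_log_probs)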
@@ -270,8 +265,8 @@ def grpo_loss(
        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

    token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
    ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
    token_log_probs = mx.stack(padded_log_probs)
    ref_token_log_probs = mx.stack(padded_ref_log_probs)

    # Calculate rewards and advantages
    rewards = mx.zeros((len(all_completions),))
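Context for the stacking in this hunk: each completion yields a per-token log-prob vector of its own length, so the vectors are padded to a common length before mx.stack can build a single (batch, max_len) array. A rough sketch of that idea; the lengths and values are made up, and padding with zeros is an assumption for illustration.

import mlx.core as mx

# Hypothetical per-sequence log-prob vectors with different lengths.
per_token_logps = [mx.array([-0.2, -1.3, -0.7]), mx.array([-0.5, -2.1])]
max_len = max(x.shape[0] for x in per_token_logps)

padded_log_probs = []
for x in per_token_logps:
    padding = mx.zeros((max_len - x.shape[0],))  # assumed pad value of 0.0
    padded_log_probs.append(mx.concatenate([x, padding]))

token_log_probs = mx.stack(padded_log_probs)  # shape: (2, 3)
print(token_log_probs.shape)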
@@ -299,7 +294,7 @@ def grpo_loss(
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # Compute policy ratio
    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))
    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))

    # Compute per-token loss following GRPO formula
    per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
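The per-token loss line is the GRPO surrogate in array form: the ratio exp(logp - stop_gradient(ref_logp)) is scaled by each sequence's advantage, a beta-weighted KL term is subtracted, and the length mask zeroes out padded positions. A toy, self-contained sketch with invented numbers; kl_div is a placeholder array here, not the trainer's actual KL computation.

import mlx.core as mx

beta = 0.1  # illustrative coefficient
token_log_probs = mx.array([[-0.2, -1.3, -0.7]])
ref_token_log_probs = mx.array([[-0.3, -1.1, -0.9]])
advantages = mx.array([0.8])                # one advantage per sequence
kl_div = mx.zeros_like(token_log_probs)     # placeholder KL term
length_mask = mx.array([[1.0, 1.0, 0.0]])   # last position is padding

policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
print(per_token_loss)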
@@ -580,7 +575,7 @@ def train_grpo(
        for i, reward_func in enumerate(reward_funcs):
            val_metrics_str += (
                f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
                # f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
            )

        print(
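For completeness, the validation string is built from each reward function's __name__, so val_metrics is expected to carry matching '<name>_mean' and '<name>_std' keys. A small illustration with a hypothetical reward function and invented numbers.

def accuracy_reward(*args, **kwargs):
    # Hypothetical reward function; only its __name__ matters for the report string.
    return 0.0

reward_funcs = [accuracy_reward]
val_metrics = {"accuracy_reward_mean": 0.412, "accuracy_reward_std": 0.095}  # invented values

val_metrics_str = ""
for i, reward_func in enumerate(reward_funcs):
    val_metrics_str += (
        f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
        f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
    )
print(val_metrics_str)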