Repository: https://github.com/ml-explore/mlx-examples.git (mirror)

Commit b7bc811507 ("nits")
Parent: e5aa2c3b5d
@@ -159,24 +159,19 @@ def get_per_token_logps(model, inputs, lengths):
     logits = model(inputs).astype(mx.float16)
     logits = logits[:, :-1, :]
     targets = inputs[:, 1:]
+    mx.eval(logits)
     per_token_logps = []
     for i in range(logits.shape[0]):
         seq_len = int(lengths[i]) - 1

         seq_logits = logits[i, :seq_len]
         seq_targets = targets[i, :seq_len]

         log_probs = nn.log_softmax(seq_logits, axis=-1)

         token_log_probs = mx.take_along_axis(
             log_probs,
             seq_targets.reshape(seq_len, 1),
             axis=-1
         ).squeeze(-1)

         per_token_logps.append(token_log_probs)
-    mx.eval(logits)
     return per_token_logps

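For reference, a minimal standalone sketch of the gather pattern this function uses: take a log-softmax over the vocabulary and index it with the shifted target ids via mx.take_along_axis. Shapes and values below are invented for illustration; only mlx.core and mlx.nn are assumed, and the per-sequence length slicing from the diff is omitted.

# Toy sketch of per-token log-probability extraction (not the repository's code).
import mlx.core as mx
import mlx.nn as nn

vocab_size = 8
seq_len = 5
seq_logits = mx.random.normal((seq_len, vocab_size))    # stand-in for one sequence's logits
seq_targets = mx.array([1, 3, 0, 7, 2])                 # stand-in next-token ids

log_probs = nn.log_softmax(seq_logits, axis=-1)         # (seq_len, vocab_size)
token_log_probs = mx.take_along_axis(
    log_probs,
    seq_targets.reshape(seq_len, 1),
    axis=-1,
).squeeze(-1)                                           # (seq_len,): log p(target_t | prefix)
mx.eval(token_log_probs)
print(token_log_probs)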
@@ -270,8 +265,8 @@ def grpo_loss(
         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))

-    token_log_probs = mx.stack(padded_log_probs).astype(mx.float32)
-    ref_token_log_probs = mx.stack(padded_ref_log_probs).astype(mx.float32)
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)

     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
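The loop above left-aligns each sequence's per-token log-probabilities and pads them to a common length so they can be stacked into one (num_sequences, max_len) batch. A rough sketch of that pad-and-stack pattern, with invented values and an assumed pad value of 0.0 (the real padding tensor, and the mask that later cancels the padded positions, are built elsewhere in grpo_loss):

# Toy pad-and-stack sketch for ragged per-token log-probs (illustrative only).
import mlx.core as mx

token_log_probs = [
    mx.array([-0.2, -1.1]),
    mx.array([-0.5]),
    mx.array([-0.3, -0.7, -2.0]),
]
max_len = max(x.shape[0] for x in token_log_probs)

padded_log_probs = []
for x in token_log_probs:
    padding = mx.zeros((max_len - x.shape[0],))           # assumed pad value of 0.0
    padded_log_probs.append(mx.concatenate([x, padding]))

batch = mx.stack(padded_log_probs)                        # shape (3, max_len)
print(batch.shape)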
@@ -299,7 +294,7 @@ def grpo_loss(
     length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

     # Compute policy ratio
-    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs), dtype=mx.float32))
+    policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)))

     # Compute per-token loss following GRPO formula
     per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)
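The per-token loss line in this hunk is the GRPO objective: a probability ratio against the gradient-stopped reference log-probs, weighted by a per-sequence advantage, minus a beta-scaled KL penalty, all restricted to valid (unpadded) positions by length_mask. Below is a toy sketch of that arithmetic with invented shapes; kl_div uses the k3-style estimator commonly seen in GRPO implementations (the repository computes kl_div outside this hunk), and the final reduction line is illustrative rather than taken from this diff.

# Toy GRPO per-token loss sketch (shapes and values invented for illustration).
import mlx.core as mx

beta = 0.1                                                                  # example KL weight
token_log_probs = mx.array([[-0.4, -0.9, -1.2], [-0.6, -0.8, -0.5]])        # (2, 3)
ref_token_log_probs = mx.array([[-0.5, -1.0, -1.0], [-0.7, -0.7, -0.6]])    # (2, 3)
advantages = mx.array([0.8, -0.3])                                          # one per sequence
lengths = mx.array([3, 2])                                                  # valid tokens per row

# Mask padded positions (simplified relative to the inputs.shape[1] - 1 form above)
length_mask = mx.arange(token_log_probs.shape[1])[None, :] < lengths[:, None]

# Policy ratio against the frozen reference, plus a k3-style KL estimate (assumed here)
ref_minus_policy = mx.stop_gradient(ref_token_log_probs) - token_log_probs
policy_ratio = mx.exp(-ref_minus_policy)            # exp(logp_policy - logp_ref)
kl_div = mx.exp(ref_minus_policy) - ref_minus_policy - 1

per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask)

# One common way to reduce to a scalar (illustrative, not from this diff)
loss = (per_token_loss.sum(axis=1) / lengths).mean()
mx.eval(loss)
print(loss.item())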
@@ -580,7 +575,7 @@ def train_grpo(
             for i, reward_func in enumerate(reward_funcs):
                 val_metrics_str += (
                     f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
-                    # f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
+                    f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                 )

             print(