mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-29 21:01:13 +08:00
updates
This commit is contained in:
parent
54e295ea80
commit
ca32424043
@ -164,6 +164,42 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
|
||||
return scores
|
||||
|
||||
|
||||
def get_per_token_logps(model, inputs, lengths):
|
||||
# Get logits from model
|
||||
logits = model(inputs).astype(mx.float32) # [batch_size, seq_len, vocab_size]
|
||||
# Remove last position as it corresponds to the next token prediction
|
||||
logits = logits[:, :-1, :] # [batch_size, seq_len-1, vocab_size]
|
||||
targets = inputs[:, 1:] # Shift inputs to get targets
|
||||
|
||||
# Process sequences individually to save memory
|
||||
per_token_logps = []
|
||||
for i in range(logits.shape[0]):
|
||||
# Get sequence length for this example
|
||||
seq_len = int(lengths[i]) - 1 # -1 because we removed last position
|
||||
|
||||
# Get logits and targets for this sequence
|
||||
seq_logits = logits[i, :seq_len] # [seq_len, vocab_size]
|
||||
seq_targets = targets[i, :seq_len] # [seq_len]
|
||||
|
||||
# Compute log probabilities
|
||||
log_probs = nn.log_softmax(seq_logits, axis=-1) # [seq_len, vocab_size]
|
||||
|
||||
# Gather log probs for actual tokens
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
seq_targets.reshape(seq_len, 1),
|
||||
axis=-1
|
||||
).squeeze(-1) # [seq_len]
|
||||
|
||||
per_token_logps.append(token_log_probs)
|
||||
|
||||
# Clean up intermediates
|
||||
del seq_logits, seq_targets, log_probs, token_log_probs
|
||||
mx.metal.clear_cache()
|
||||
|
||||
return per_token_logps
|
||||
|
||||
|
||||
def grpo_loss(
|
||||
model,
|
||||
tokenizer,
|
||||
@ -248,24 +284,30 @@ def grpo_loss(
|
||||
targets = inputs[:, 1:]
|
||||
|
||||
# Current policy probabilities
|
||||
token_log_probs = mx.take_along_axis(
|
||||
log_probs,
|
||||
targets.reshape(*targets.shape, 1),
|
||||
axis=-1
|
||||
).squeeze(-1)
|
||||
token_log_probs = get_per_token_logps(model, inputs, lengths)
|
||||
|
||||
# Reference policy probabilities
|
||||
if ref_model is not None:
|
||||
ref_logits = ref_model(inputs).astype(mx.float32)
|
||||
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
|
||||
else:
|
||||
ref_logits = mx.array(logits)
|
||||
ref_token_log_probs = token_log_probs
|
||||
|
||||
ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
|
||||
ref_token_log_probs = mx.take_along_axis(
|
||||
ref_log_probs,
|
||||
targets.reshape(*targets.shape, 1),
|
||||
axis=-1
|
||||
).squeeze(-1)
|
||||
max_len = max(x.shape[0] for x in token_log_probs)
|
||||
padded_log_probs = []
|
||||
padded_ref_log_probs = []
|
||||
|
||||
for i in range(len(token_log_probs)):
|
||||
seq_len = token_log_probs[i].shape[0]
|
||||
padding = mx.zeros((max_len - seq_len,), dtype=mx.float32)
|
||||
|
||||
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
|
||||
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
|
||||
|
||||
del padding
|
||||
mx.metal.clear_cache()
|
||||
|
||||
token_log_probs = mx.stack(padded_log_probs)
|
||||
ref_token_log_probs = mx.stack(padded_ref_log_probs)
|
||||
|
||||
# Calculate rewards and advantages
|
||||
rewards = mx.zeros((len(all_completions),))
|
||||
|
Loading…
Reference in New Issue
Block a user