Goekdeniz-Guelmez 2025-02-03 21:57:26 +01:00
parent 54e295ea80
commit ca32424043

@@ -164,6 +164,42 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
 
+def get_per_token_logps(model, inputs, lengths):
+    # Get logits from model
+    logits = model(inputs).astype(mx.float32)  # [batch_size, seq_len, vocab_size]
+    # Remove last position as it corresponds to the next token prediction
+    logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
+    targets = inputs[:, 1:]  # Shift inputs to get targets
+
+    # Process sequences individually to save memory
+    per_token_logps = []
+    for i in range(logits.shape[0]):
+        # Get sequence length for this example
+        seq_len = int(lengths[i]) - 1  # -1 because we removed the last position
+
+        # Get logits and targets for this sequence
+        seq_logits = logits[i, :seq_len]  # [seq_len, vocab_size]
+        seq_targets = targets[i, :seq_len]  # [seq_len]
+
+        # Compute log probabilities
+        log_probs = nn.log_softmax(seq_logits, axis=-1)  # [seq_len, vocab_size]
+
+        # Gather log probs for actual tokens
+        token_log_probs = mx.take_along_axis(
+            log_probs,
+            seq_targets.reshape(seq_len, 1),
+            axis=-1
+        ).squeeze(-1)  # [seq_len]
+
+        per_token_logps.append(token_log_probs)
+
+        # Clean up intermediates
+        del seq_logits, seq_targets, log_probs, token_log_probs
+        mx.metal.clear_cache()
+
+    return per_token_logps
+
+
 def grpo_loss(
     model,
     tokenizer,
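
A minimal sketch (not part of the commit) of the gather that get_per_token_logps performs for a single sequence, using toy shapes and made-up token ids; it assumes only mlx is available and mirrors the log_softmax + take_along_axis pattern above.

    import mlx.core as mx
    import mlx.nn as nn

    # Toy example: one sequence of 4 target tokens over a 10-token vocabulary.
    seq_len, vocab_size = 4, 10
    seq_logits = mx.random.normal((seq_len, vocab_size))   # logits for positions 0..3
    seq_targets = mx.array([1, 4, 2, 7])                    # next-token ids at those positions

    log_probs = nn.log_softmax(seq_logits, axis=-1)          # [seq_len, vocab_size]
    token_log_probs = mx.take_along_axis(
        log_probs,
        seq_targets.reshape(seq_len, 1),
        axis=-1,
    ).squeeze(-1)                                            # [seq_len] = log p(target_t | prefix)
    print(token_log_probs.shape)                             # (4,)
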
@@ -248,24 +284,30 @@ def grpo_loss(
     targets = inputs[:, 1:]
 
     # Current policy probabilities
-    token_log_probs = mx.take_along_axis(
-        log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
+    token_log_probs = get_per_token_logps(model, inputs, lengths)
 
     # Reference policy probabilities
     if ref_model is not None:
-        ref_logits = ref_model(inputs).astype(mx.float32)
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
     else:
-        ref_logits = mx.array(logits)
-
-    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
-    ref_token_log_probs = mx.take_along_axis(
-        ref_log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
+        ref_token_log_probs = token_log_probs
+
+    max_len = max(x.shape[0] for x in token_log_probs)
+    padded_log_probs = []
+    padded_ref_log_probs = []
+
+    for i in range(len(token_log_probs)):
+        seq_len = token_log_probs[i].shape[0]
+        padding = mx.zeros((max_len - seq_len,), dtype=mx.float32)
+
+        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+        del padding
+        mx.metal.clear_cache()
+
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
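
Because get_per_token_logps now returns one variable-length vector per sequence, the loss pads them to a common length before stacking, as in the hunk above. A small illustrative sketch of that padding pattern with dummy values (the lengths and the downstream masking are assumptions for illustration, not the trainer's real batch):

    import mlx.core as mx

    # Ragged per-sequence log-prob vectors, e.g. completions of length 3, 5, and 2.
    per_token_logps = [mx.zeros((3,)), mx.zeros((5,)), mx.zeros((2,))]

    max_len = max(x.shape[0] for x in per_token_logps)
    padded = [
        mx.concatenate([x, mx.zeros((max_len - x.shape[0],), dtype=mx.float32)])
        for x in per_token_logps
    ]
    batch = mx.stack(padded)   # shape [3, 5]; a length mask downstream would keep the zero padding out of the loss
    print(batch.shape)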