mirror of https://github.com/ml-explore/mlx-examples.git
updates

This commit is contained in:
parent 54e295ea80
commit ca32424043
@@ -164,6 +164,42 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li
     return scores
 
 
+def get_per_token_logps(model, inputs, lengths):
+    # Get logits from model
+    logits = model(inputs).astype(mx.float32)  # [batch_size, seq_len, vocab_size]
+    # Remove last position as it corresponds to the next token prediction
+    logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
+    targets = inputs[:, 1:]  # Shift inputs to get targets
+
+    # Process sequences individually to save memory
+    per_token_logps = []
+    for i in range(logits.shape[0]):
+        # Get sequence length for this example
+        seq_len = int(lengths[i]) - 1  # -1 because we removed last position
+
+        # Get logits and targets for this sequence
+        seq_logits = logits[i, :seq_len]  # [seq_len, vocab_size]
+        seq_targets = targets[i, :seq_len]  # [seq_len]
+
+        # Compute log probabilities
+        log_probs = nn.log_softmax(seq_logits, axis=-1)  # [seq_len, vocab_size]
+
+        # Gather log probs for actual tokens
+        token_log_probs = mx.take_along_axis(
+            log_probs,
+            seq_targets.reshape(seq_len, 1),
+            axis=-1
+        ).squeeze(-1)  # [seq_len]
+
+        per_token_logps.append(token_log_probs)
+
+        # Clean up intermediates
+        del seq_logits, seq_targets, log_probs, token_log_probs
+        mx.metal.clear_cache()
+
+    return per_token_logps
+
+
 def grpo_loss(
     model,
     tokenizer,
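For context, a minimal sketch of how the new helper can be exercised in isolation. Everything below (the ToyLM class, vocabulary size, token ids, and lengths) is a hypothetical stand-in, not part of the commit; any mlx.nn module that maps [batch, seq_len] token ids to [batch, seq_len, vocab_size] logits would work the same way.

    import mlx.core as mx
    import mlx.nn as nn

    class ToyLM(nn.Module):
        # Hypothetical stand-in for a language model: embed token ids,
        # then project straight back to the vocabulary.
        def __init__(self, vocab_size=16, dims=8):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dims)
            self.out = nn.Linear(dims, vocab_size)

        def __call__(self, inputs):
            return self.out(self.embed(inputs))  # [batch, seq_len, vocab_size]

    model = ToyLM()
    # Two right-padded sequences; `lengths` records the unpadded lengths.
    inputs = mx.array([[1, 2, 3, 4, 0], [5, 6, 7, 0, 0]])
    lengths = mx.array([4, 3])

    per_token_logps = get_per_token_logps(model, inputs, lengths)
    print([lp.shape for lp in per_token_logps])  # [(3,), (2,)]

Each returned entry has one log-probability per predicted token (length minus one, since the last logit position has no target), which is why the results are ragged and need the padding step in the next hunk.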
@@ -248,24 +284,30 @@ def grpo_loss(
     targets = inputs[:, 1:]
 
     # Current policy probabilities
-    token_log_probs = mx.take_along_axis(
-        log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
+    token_log_probs = get_per_token_logps(model, inputs, lengths)
 
     # Reference policy probabilities
     if ref_model is not None:
-        ref_logits = ref_model(inputs).astype(mx.float32)
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
     else:
-        ref_logits = mx.array(logits)
+        ref_token_log_probs = token_log_probs
 
-    ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1)
-    ref_token_log_probs = mx.take_along_axis(
-        ref_log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
+    max_len = max(x.shape[0] for x in token_log_probs)
+    padded_log_probs = []
+    padded_ref_log_probs = []
+
+    for i in range(len(token_log_probs)):
+        seq_len = token_log_probs[i].shape[0]
+        padding = mx.zeros((max_len - seq_len,), dtype=mx.float32)
+
+        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+
+        del padding
+        mx.metal.clear_cache()
+
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
     # Calculate rewards and advantages
     rewards = mx.zeros((len(all_completions),))
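One consequence of the zero-padding above is that downstream per-token statistics need a length mask so padded positions do not dilute averages. The hunk does not show that step; the sketch below, with made-up shapes and a hypothetical `lengths` array, illustrates one way such a mask could be built with broadcasting.

    import mlx.core as mx

    # Hypothetical example: 2 sequences padded to max_len = 4,
    # with original (pre-shift) lengths 5 and 3.
    token_log_probs = mx.zeros((2, 4))
    lengths = mx.array([5, 3])

    max_len = token_log_probs.shape[1]
    # Position j is valid while j < lengths[i] - 1 (one token is lost to the shift).
    length_mask = mx.arange(max_len)[None, :] < (lengths[:, None] - 1)

    # Masked mean: padded zeros no longer bias the per-sequence average.
    mean_logps = (token_log_probs * length_mask).sum(axis=1) / length_mask.sum(axis=1)
    print(mean_logps.shape)  # (2,)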