Mirror of https://github.com/ml-explore/mlx-examples.git
Commit 925e11439b ("updates"), parent 80e10a59d7.
@@ -76,7 +76,6 @@ def generate_grpo(
    # Process in batches
    for batch_start in range(0, total_samples, batch_size):
        batch_end = min(batch_start + batch_size, total_samples)
        batch_results = []

        if is_training:
            # Training mode with batched processing
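The hunk above walks the samples in fixed-size chunks. A minimal, self-contained sketch of that chunking pattern (the values of total_samples and batch_size here are illustrative, not taken from the diff):

    total_samples, batch_size = 10, 4
    for batch_start in range(0, total_samples, batch_size):
        batch_end = min(batch_start + batch_size, total_samples)
        print(batch_start, batch_end)  # -> 0 4, 4 8, 8 10 (the last chunk may be short)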
@@ -92,26 +91,46 @@ def generate_grpo(
            # Track tokens for each sequence in the batch
            batch_tokens = [[] for _ in range(batch_end - batch_start)]
            active_indices = list(range(batch_end - batch_start))

            # Initial token generation for all sequences in batch
            for i in range(len(batch_logits)):
                logits_temp = batch_logits[i] / temperature
                next_token = mx.random.categorical(logits_temp)
                token = next_token.item()
                mx.eval(logits_temp, next_token, token)
                batch_tokens[i].append(token)

                # Check if this token already completes the sequence
                if token == tokenizer.eos_token_id:
                    continue
                else:
                    # Set up for next token
                    current_input = mx.array([token])
                    batch_logits[i] = model(current_input[None], cache=prompt_caches[i])[:, -1]

            mx.eval(batch_logits)
            active_indices = [i for i, tokens in enumerate(batch_tokens) if tokens[-1] != tokenizer.eos_token_id and len(tokens) < max_tokens]

            # Generate tokens until all sequences are complete
            while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
                next_active = []
                for idx in active_indices:
                    logits_temp = batch_logits[idx] / temperature
                    probs = nn.softmax(logits_temp, axis=-1)
                    next_token = mx.random.categorical(logits_temp)
                    token = next_token.item()

                    test_sequence = batch_tokens[idx] + [token]
                    is_end = (len(test_sequence) >= len(end_sequence) and
                              mx.array_equal(
                                  mx.array(test_sequence[-len(end_sequence):]),
                                  end_sequence
                              ))

                    mx.eval(logits_temp, next_token, token)
                    batch_tokens[idx].append(token)

                    # Check for end sequence
                    if len(batch_tokens[idx]) >= len(end_sequence):
                        test_sequence = batch_tokens[idx][-len(end_sequence):]
                        is_end = mx.array_equal(
                            mx.array(test_sequence),
                            end_sequence
                        )
                    else:
                        is_end = False

                    if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
                        # This sequence is done
                        pass
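Each generation step above scales the last-position logits by the temperature, samples with mx.random.categorical, and then checks whether the sequence now ends with the end marker. A hedged, standalone sketch of those two steps (the helper names sample_token and ends_with are ours, not the repository's; mx.random.categorical takes unnormalized log-probabilities, so no explicit softmax is required):

    import mlx.core as mx

    def sample_token(logits: mx.array, temperature: float = 1.0) -> int:
        # Scale the logits and draw one token id from the categorical distribution.
        next_token = mx.random.categorical(logits / temperature)
        return next_token.item()

    def ends_with(tokens: list, end_sequence: list) -> bool:
        # True once the tail of `tokens` matches `end_sequence` (plain Python lists
        # here; the diff compares mx.array values with mx.array_equal instead).
        n = len(end_sequence)
        return len(tokens) >= n and tokens[-n:] == end_sequence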
@@ -123,12 +142,31 @@ def generate_grpo(
                mx.eval([batch_logits[idx] for idx in next_active])
                active_indices = next_active

            # Clear caches after processing this batch
            for pc in prompt_caches:
                del pc

            # Add batch results to overall results
            for tokens in batch_tokens:
                if tokens:
                    results.append(mx.array(tokens))

                # Filter out any special tokens that might appear after the end token
                if len(tokens) >= len(end_sequence):
                    for i in range(len(tokens) - len(end_sequence) + 1):
                        if mx.array_equal(
                            mx.array(tokens[i:i+len(end_sequence)]),
                            end_sequence
                        ):
                            tokens = tokens[:i+len(end_sequence)]
                            break

                # Filter out EOS token if it's the last token
                if tokens and tokens[-1] == tokenizer.eos_token_id:
                    tokens = tokens[:-1]

                # Only add non-empty token lists
                if tokens:
                    results.append(mx.array(tokens))
        else:
            # Non-training mode with batched processing
            for idx in range(batch_start, batch_end):
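The post-processing above truncates each completion at the first occurrence of the end sequence and then drops a trailing EOS token. A compact sketch of the same filtering as a standalone helper, using plain Python lists (trim_completion is our name; the diff operates on mx.array values):

    def trim_completion(tokens: list, end_sequence: list, eos_id: int) -> list:
        n = len(end_sequence)
        # Cut everything after the first occurrence of the end sequence.
        if n and len(tokens) >= n:
            for i in range(len(tokens) - n + 1):
                if tokens[i:i + n] == end_sequence:
                    tokens = tokens[:i + n]
                    break
        # Drop a trailing EOS token, if any.
        if tokens and tokens[-1] == eos_id:
            tokens = tokens[:-1]
        return tokens

    # Example: trim_completion([5, 6, 7, 8, 9, 2], end_sequence=[7, 8], eos_id=2) -> [5, 6, 7, 8]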
@@ -158,7 +196,6 @@ def generate_grpo(
                results.append(mx.array(current_tokens))

        mx.metal.clear_cache()

    mx.eval(results)
    return results
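MLX evaluates lazily, so the mx.eval(results) call above is what materializes the generated token arrays, while mx.metal.clear_cache() releases buffers held by the Metal allocator between batches. A tiny illustration of that pattern (the arrays are placeholders):

    import mlx.core as mx

    results = [mx.array([1, 2, 3]), mx.array([4, 5])]
    mx.eval(results)        # force evaluation of every array in the list
    mx.metal.clear_cache()  # return cached GPU memory to the allocator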
@@ -267,14 +304,7 @@ def grpo_loss(
    # If we didn't generate any completions, return early
    if not all_completions:
        print("No completions were generated. Returning zero loss.")
        dummy_loss = mx.zeros(1)
        dummy_metrics = {
            'total_rewards_mean': mx.zeros(1),
            'total_rewards_std': mx.zeros(1),
            'kl': mx.zeros(1)
        }
        return dummy_loss, mx.array(0), dummy_metrics
        raise ValueError("No completions were generated. Please check your model and inputs.")

    # Create expanded prompts and answers based on actual generated completions
    expanded_answers = []
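Judging from the hunk line counts (14 lines before, 7 after), the empty-completion path now fails fast with a ValueError rather than returning dummy zero-valued loss and metrics. A minimal, hypothetical illustration of that guard pattern (loss_over_completions is our name, not the repository's grpo_loss):

    import mlx.core as mx

    def loss_over_completions(completions: list) -> mx.array:
        if not completions:
            raise ValueError("No completions were generated. Please check your model and inputs.")
        return mx.mean(mx.concatenate(completions).astype(mx.float32))

    # loss_over_completions([])  # raises ValueError instead of returning mx.zeros(1)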
@@ -453,11 +483,24 @@ def grpo_loss(
    if is_validation and all_completion_texts:
        print("\n=== Validation Sample Details ===")

        # Print the input context (prompt)
        last_prompt_idx = batch_indices[-1] if batch_indices else 0

        if last_prompt_idx < len(prompt_text):
            print(f"\n📋 Raw Prompt:\n{prompt_text[last_prompt_idx]}")
            print("\n" + "="*10 + "\n")

        # Get the actual tokenized prompt that was fed to the model
        if last_prompt_idx < len(prompt_tokens):
            actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
            print(f"\n🔄 Model Input:\n{actual_prompt}")
            print("\n" + "="*10 + "\n")

        print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
        print("\n" + "="*10 + "\n")

        # Make sure we have a valid index for answer_text
        last_prompt_idx = batch_indices[-1] if batch_indices else 0
        if last_prompt_idx < len(answer_text):
            print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
            print("\n" + "="*10 + "\n")
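The validation block prints the raw prompt, the decoded model input, the last generation, and the reference answer, each behind an index guard. A condensed sketch of that logging as a helper (the name log_validation_sample and its parameter order are ours):

    def log_validation_sample(prompt_text, answer_text, completions, batch_indices):
        idx = batch_indices[-1] if batch_indices else 0
        sep = "\n" + "=" * 10 + "\n"
        if idx < len(prompt_text):
            print(f"\n📋 Raw Prompt:\n{prompt_text[idx]}{sep}")
        if completions:
            print(f"\n📝 Generation:\n{completions[-1]}{sep}")
        if idx < len(answer_text):
            print(f"\n✅ Answer:\n{answer_text[idx]}{sep}")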