Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-26 10:41:18 +08:00)
batching fix
This commit is contained in:
  parent a04eb02257
  commit 15d53279ae
@@ -51,74 +51,114 @@ class GRPOTrainingArgs(TrainingArgs):
 def generate_grpo(
-    model: nn.Module,
-    prompts,
-    max_tokens,
-    tokenizer,
-    group_size,
-    is_training=False,
-    end_token: str = "</answer>",
-    temperature: float = 0.8
-):
+    model: nn.Module,
+    prompts,
+    max_tokens,
+    tokenizer,
+    group_size,
+    is_training=False,
+    end_token: str = "</answer>",
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-    batch_size = prompts.shape[0] * group_size
+    total_samples = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
     mx.eval(expanded_prompts)
-    try:
-        for idx in range(batch_size):
-            current_tokens = []
-            if is_training:
-                current_input = expanded_prompts[idx]
-                prompt_cache = cache.make_prompt_cache(model)
-                logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                mx.eval(logits, prompt_cache)
-                while len(current_tokens) < max_tokens:
-                    logits_temp = logits / temperature
-                    probs = nn.softmax(logits_temp, axis=-1)
-                    next_token = mx.random.categorical(logits_temp)
-                    token = next_token.item()
-                    test_sequence = current_tokens + [token]
-                    if (len(test_sequence) >= len(end_sequence) and
-                        mx.array_equal(
-                            mx.array(test_sequence[-len(end_sequence):]),
-                            end_sequence
-                        )):
-                        current_tokens.append(token)
-                        break
-                    if token == tokenizer.eos_token_id:
-                        break
-                    current_tokens.append(token)
-                    current_input = mx.array([token])
-                    logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                    mx.eval(current_input, logits, probs, next_token, token)
-            else:
-                generator = generate_step(
-                    expanded_prompts[idx],
-                    model,
-                    max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x / temperature)
-                )
-                for token, _ in generator:
-                    test_sequence = current_tokens + [token]
-                    if (len(test_sequence) >= len(end_sequence) and
-                        mx.array_equal(
-                            mx.array(test_sequence[-len(end_sequence):]),
-                            end_sequence
-                        )):
-                        current_tokens.append(token)
-                        break
-
-                    if token == tokenizer.eos_token_id:
-                        break
-                    current_tokens.append(token)
-            if current_tokens:
-                results.append(mx.array(current_tokens))
+    try:
+        # Process in batches
+        for batch_start in range(0, total_samples, batch_size):
+            batch_end = min(batch_start + batch_size, total_samples)
+            batch_results = []
+
+            if is_training:
+                # Training mode with batched processing
+                batch_inputs = expanded_prompts[batch_start:batch_end]
+                prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
+
+                # Initial forward pass for all prompts in batch
+                batch_logits = []
+                for i, prompt in enumerate(batch_inputs):
+                    logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
+                    batch_logits.append(logits)
+                mx.eval(batch_logits, prompt_caches)
+
+                # Track tokens for each sequence in the batch
+                batch_tokens = [[] for _ in range(batch_end - batch_start)]
+                active_indices = list(range(batch_end - batch_start))
+
+                # Generate tokens until all sequences are complete
+                while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
+                    next_active = []
+                    for idx in active_indices:
+                        logits_temp = batch_logits[idx] / temperature
+                        probs = nn.softmax(logits_temp, axis=-1)
+                        next_token = mx.random.categorical(logits_temp)
+                        token = next_token.item()
+
+                        test_sequence = batch_tokens[idx] + [token]
+                        is_end = (len(test_sequence) >= len(end_sequence) and
+                            mx.array_equal(
+                                mx.array(test_sequence[-len(end_sequence):]),
+                                end_sequence
+                            ))
+
+                        batch_tokens[idx].append(token)
+
+                        if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
+                            # This sequence is done
+                            pass
+                        else:
+                            # Continue with this sequence
+                            next_active.append(idx)
+                            current_input = mx.array([token])
+                            batch_logits[idx] = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+
+                    mx.eval([batch_logits[idx] for idx in next_active])
+                    active_indices = next_active
+
+                # Add batch results to overall results
+                for tokens in batch_tokens:
+                    if tokens:
+                        results.append(mx.array(tokens))
+
+            else:
+                # Non-training mode with batched processing
+                for idx in range(batch_start, batch_end):
+                    current_tokens = []
+                    generator = generate_step(
+                        expanded_prompts[idx],
+                        model,
+                        max_tokens=max_tokens,
+                        sampler=lambda x: mx.random.categorical(x / temperature)
+                    )
+
+                    for token, _ in generator:
+                        test_sequence = current_tokens + [token]
+                        if (len(test_sequence) >= len(end_sequence) and
+                            mx.array_equal(
+                                mx.array(test_sequence[-len(end_sequence):]),
+                                end_sequence
+                            )):
+                            current_tokens.append(token)
+                            break
+
+                        if token == tokenizer.eos_token_id:
+                            break
+                        current_tokens.append(token)
+
+                    if current_tokens:
+                        results.append(mx.array(current_tokens))
+
     mx.metal.clear_cache()
 
     mx.eval(results)
     return results
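Both the per-sample and the batched decode paths above stop when the last sampled tokens match the encoded end token. A minimal standalone sketch of that tail check, assuming an MLX array of end-token ids; the helper name ends_with_end_token is illustrative and not part of the commit:

import mlx.core as mx

def ends_with_end_token(tokens: list, end_sequence: mx.array) -> bool:
    # Not enough tokens yet to contain the end marker.
    if len(tokens) < len(end_sequence):
        return False
    # Compare the trailing window of generated ids against the encoded end token.
    tail = mx.array(tokens[-len(end_sequence):])
    return mx.array_equal(tail, end_sequence).item()

# Usage sketch:
#   end_sequence = mx.array(tokenizer.encode("</answer>"))
#   stop once ends_with_end_token(current_tokens + [token], end_sequence) is True.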
@@ -151,24 +191,39 @@ def grpo_loss(
     ref_model,
     tokenizer,
     batch,
-    reward_funcs=None,
-    beta=0.1,
-    group_size=4,
-    epsilon=1e-4,
-    max_tokens=64,
-    temperature=1.0,
-    reward_weights=None,
-    is_validation=False
+    reward_funcs: Optional[List[RewardFunctions]] = None,
+    beta: float = 0.1,
+    group_size: int = 4,
+    epsilon: float = 1e-4,
+    max_tokens: int = 64,
+    temperature: float = 0.8,
+    reward_weights: Optional[List[float]] = None,
+    is_validation: bool = False,
+    batch_size: int = 1
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
-    batch_size = len(prompt_tokens)
+    total_samples = len(prompt_tokens)
 
     all_completions = []
     all_completion_texts = []
+    batch_indices = []  # Keep track of which batch each completion belongs to
 
-    for i in range(0, batch_size, batch_size):
-        batch_prompts = prompt_tokens[i:i+batch_size]
-        prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        # Get actual batch size for this iteration (might be smaller for the last batch)
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i:i+current_batch_size]
+
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+
+        # Convert to tensor
+        prompt_tensor = mx.array(padded_prompts)
 
         try:
             if is_validation:
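The padding added in this hunk exists only so that variable-length prompts stack into one tensor per mini-batch. A small sketch of that step, assuming a tokenizer exposing pad_token_id (right-padding only; padded positions are masked downstream):

import mlx.core as mx

def pad_batch(batch_prompts, pad_token_id):
    # Right-pad every prompt to the longest prompt in the batch so they stack into one array.
    max_prompt_len = max(len(p) for p in batch_prompts)
    padded = [p + [pad_token_id] * (max_prompt_len - len(p)) for p in batch_prompts]
    return mx.array(padded)

# e.g. pad_batch([[1, 2, 3], [4, 5]], pad_token_id=0) -> shape (2, 3), second row [4, 5, 0]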
@@ -178,7 +233,8 @@ def grpo_loss(
                     max_tokens,
                     tokenizer,
                     group_size,
-                    temperature=temperature
+                    temperature=temperature,
+                    batch_size=current_batch_size
                 )
                 model.train()
             else:
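For context, this hunk only threads the per-iteration batch size into the validation-path generation call. A hedged usage sketch of that call; the leading lines and positional argument order are inferred from the generate_grpo signature above, not shown verbatim in this hunk:

# Inside the per-batch loop, validation path: switch to eval mode, sample, then restore training mode.
model.eval()
completions = generate_grpo(
    model,
    prompt_tensor,
    max_tokens,
    tokenizer,
    group_size,
    temperature=temperature,
    batch_size=current_batch_size,
)
model.train()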
@@ -189,26 +245,69 @@ def grpo_loss(
                     tokenizer,
                     group_size,
                     is_training=True,
-                    temperature=temperature
+                    temperature=temperature,
+                    batch_size=current_batch_size
                 )
 
             if completions is not None:
-                for completion_ids in completions:
-                    completion_text = tokenizer.decode(completion_ids.tolist())
-                    all_completions.append(completion_ids)
-                    all_completion_texts.append(completion_text)
-                    mx.eval(completion_ids)
+                for j, completion_ids in enumerate(completions):
+                    # Calculate which prompt this completion belongs to
+                    prompt_idx = i + (j // group_size)
+                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
         except Exception as e:
             print(f"Generation error: {e}")
            continue
 
     mx.metal.clear_cache()
 
+    # If we didn't generate any completions, return early
+    if not all_completions:
+        print("No completions were generated. Returning zero loss.")
+        dummy_loss = mx.zeros(1)
+        dummy_metrics = {
+            'total_rewards_mean': mx.zeros(1),
+            'total_rewards_std': mx.zeros(1),
+            'kl': mx.zeros(1)
+        }
+        return dummy_loss, mx.array(0), dummy_metrics
+
+    # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []
-    for i in range(batch_size):
-        expanded_answers.extend([answer_text[i]] * group_size)
-        expanded_prompts.extend([prompt_text[i]] * group_size)
+    # Group completions by their original prompt
+    unique_prompt_indices = sorted(set(batch_indices))
+    grouped_completions = {idx: [] for idx in unique_prompt_indices}
+
+    for i, completion_idx in enumerate(batch_indices):
+        grouped_completions[completion_idx].append(i)
+
+    # Rebuild completions in the correct order
+    ordered_completions = []
+    ordered_completion_texts = []
+    ordered_batch_indices = []
+
+    for prompt_idx in unique_prompt_indices:
+        completion_indices = grouped_completions[prompt_idx]
+        for idx in completion_indices:
+            ordered_completions.append(all_completions[idx])
+            ordered_completion_texts.append(all_completion_texts[idx])
+            ordered_batch_indices.append(prompt_idx)
+
+            # Add corresponding prompt and answer
+            expanded_prompts.append(prompt_text[prompt_idx])
+            expanded_answers.append(answer_text[prompt_idx])
+
+    all_completions = ordered_completions
+    all_completion_texts = ordered_completion_texts
+    batch_indices = ordered_batch_indices
+
+    # Continue with the rest of the function
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
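The index arithmetic introduced here maps completion j of a batch that starts at prompt offset i back to its source prompt: completions come out group_size per prompt, in prompt order. A tiny worked sketch of that mapping (numbers chosen only for illustration):

# group_size = 2, batch of prompts starting at global prompt index i = 4
# completions are ordered [p4_c0, p4_c1, p5_c0, p5_c1]
group_size = 2
i = 4
for j in range(4):
    prompt_idx = i + (j // group_size)
    print(j, "->", prompt_idx)  # 0 -> 4, 1 -> 4, 2 -> 5, 3 -> 5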
@@ -231,7 +330,6 @@ def grpo_loss(
 
     # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)
 
-    mx.eval(token_log_probs)
 
     if ref_model is None:
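get_per_token_logps itself is not part of this hunk. As an assumed sketch only (names, signature, and masking details are not taken from the repository), a common way to compute per-token log-probabilities for next-token prediction in MLX looks like this:

import mlx.core as mx
import mlx.nn as nn

def per_token_logps_sketch(model, inputs, lengths):
    # Logits at position t score the token at position t + 1.
    logits = model(inputs)[:, :-1, :]
    targets = inputs[:, 1:]
    log_probs = nn.log_softmax(logits, axis=-1)
    # Pick the log-probability of each realized target token.
    per_token = mx.take_along_axis(log_probs, targets[..., None], axis=-1).squeeze(-1)
    # Callers typically mask out positions beyond each sequence length before reducing.
    return per_token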
@@ -282,11 +380,31 @@ def grpo_loss(
 
     rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
 
-    # Reshape rewards and compute advantages
-    rewards_reshaped = rewards.reshape(batch_size, group_size)
-    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
-    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
-    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
+    # Get number of unique prompts
+    num_unique_prompts = len(unique_prompt_indices)
+
+    # Reshape rewards based on actual groups
+    rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
+    for i, prompt_idx in enumerate(batch_indices):
+        prompt_position = unique_prompt_indices.index(prompt_idx)
+        rewards_by_prompt[prompt_position].append(rewards[i])
+
+    # Calculate advantages within each group
+    advantages = mx.zeros_like(rewards)
+    for i, prompt_rewards in enumerate(rewards_by_prompt):
+        if len(prompt_rewards) > 1:  # Only normalize if we have multiple samples
+            prompt_rewards = mx.array(prompt_rewards)
+            mean_reward = mx.mean(prompt_rewards)
+            std_reward = mx.std(prompt_rewards)
+
+            # Find indices for this prompt
+            indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]]
+            for j, idx in enumerate(indices):
+                advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon)
+        else:
+            # If only one sample, advantage is 0
+            idx = batch_indices.index(unique_prompt_indices[i])
+            advantages[idx] = 0.0
 
     # Compute KL divergence using Schulman's approximator
     kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
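Two pieces of math are at work in this hunk: advantages are normalized within each prompt's group of completions, A_i = (r_i - mean(group)) / (std(group) + epsilon), and the KL term uses the k3 estimator exp(q - p) - (q - p) - 1, with p the current policy log-prob and q the reference log-prob. A compact numeric sketch with invented values, shown only to make the shapes of these two formulas concrete:

import mlx.core as mx

rewards = mx.array([1.0, 0.0, 0.5, 0.5])        # one group of 4 completions for a single prompt
adv = (rewards - mx.mean(rewards)) / (mx.std(rewards) + 1e-4)

logp = mx.array([-1.2, -0.7])                    # current policy per-token log-probs
logq = mx.array([-1.0, -0.9])                    # reference model per-token log-probs
k3 = mx.exp(logq - logp) - (logq - logp) - 1     # always >= 0, unbiased estimate of KL(policy || ref)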
@@ -320,23 +438,35 @@ def grpo_loss(
         reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
 
+    grouped_rewards_mean = mx.array([mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt])
+    grouped_rewards_std = mx.array([mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) for rewards in rewards_by_prompt])
+
     metrics = {
         'total_rewards_mean': mx.mean(rewards),
         'total_rewards_std': mx.std(rewards),
-        'grouped_rewards_mean': mx.mean(rewards_reshaped),
-        'grouped_rewards_std': mx.std(rewards_reshaped),
+        'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
+        'grouped_rewards_std': mx.mean(grouped_rewards_std),
         'kl': mean_kl,
         **reward_metrics
     }
 
-    if is_validation:
+    if is_validation and all_completion_texts:
         print("\n=== Validation Sample Details ===")
         print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
         print("\n" + "="*10 + "\n")
-        print(f"\n✅ Answer:\n{answer_text[-1]}")
-        print("\n" + "="*10 + "\n")
-        print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
-        print("\n" + "="*30 + "\n")
+
+        # Make sure we have a valid index for answer_text
+        last_prompt_idx = batch_indices[-1] if batch_indices else 0
+        if last_prompt_idx < len(answer_text):
+            print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
+            print("\n" + "="*10 + "\n")
+
+        # Only try to extract if r1_extract_xml_answer is defined
+        if 'r1_extract_xml_answer' in globals():
+            print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+            print("\n" + "="*35 + "\n")
 
     mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics