batching fix

Author: Goekdeniz-Guelmez
Date:   2025-02-28 16:02:40 +01:00
parent a04eb02257
commit 15d53279ae


@@ -58,51 +58,88 @@ def generate_grpo(
    group_size,
    is_training=False,
    end_token: str = "</answer>",
    temperature: float = 0.8,
    batch_size: int = 1
):
    if len(prompts.shape) == 1:
        prompts = prompts[None, :]
    if prompts.shape[1] == 0:
        return None

    total_samples = prompts.shape[0] * group_size
    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
    end_sequence = mx.array(tokenizer.encode(end_token))
    results = []
    mx.eval(expanded_prompts)

    try:
        # Process the expanded prompts in batches
        for batch_start in range(0, total_samples, batch_size):
            batch_end = min(batch_start + batch_size, total_samples)
            batch_results = []

            if is_training:
                # Training mode with batched processing
                batch_inputs = expanded_prompts[batch_start:batch_end]
                prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]

                # Initial forward pass for all prompts in the batch
                batch_logits = []
                for i, prompt in enumerate(batch_inputs):
                    logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
                    batch_logits.append(logits)
                mx.eval(batch_logits, prompt_caches)

                # Track generated tokens for each sequence in the batch
                batch_tokens = [[] for _ in range(batch_end - batch_start)]
                active_indices = list(range(batch_end - batch_start))

                # Generate tokens until every sequence is complete
                while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
                    next_active = []

                    for idx in active_indices:
                        logits_temp = batch_logits[idx] / temperature
                        probs = nn.softmax(logits_temp, axis=-1)
                        next_token = mx.random.categorical(logits_temp)
                        token = next_token.item()

                        test_sequence = batch_tokens[idx] + [token]
                        is_end = (len(test_sequence) >= len(end_sequence) and
                            mx.array_equal(
                                mx.array(test_sequence[-len(end_sequence):]),
                                end_sequence
                            ))

                        batch_tokens[idx].append(token)

                        if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
                            # This sequence is done
                            pass
                        else:
                            # Continue generating for this sequence
                            next_active.append(idx)
                            current_input = mx.array([token])
                            batch_logits[idx] = model(current_input[None], cache=prompt_caches[idx])[:, -1]

                    mx.eval([batch_logits[idx] for idx in next_active])
                    active_indices = next_active

                # Add this batch's sequences to the overall results
                for tokens in batch_tokens:
                    if tokens:
                        results.append(mx.array(tokens))
            else:
                # Non-training mode: generate one sequence at a time
                for idx in range(batch_start, batch_end):
                    current_tokens = []
                    generator = generate_step(
                        expanded_prompts[idx],
                        model,
                        max_tokens=max_tokens,
                        sampler=lambda x: mx.random.categorical(x / temperature)
                    )
                    for token, _ in generator:
                        test_sequence = current_tokens + [token]
                        if (len(test_sequence) >= len(end_sequence) and
@@ -116,9 +153,12 @@ def generate_grpo(
                        if token == tokenizer.eos_token_id:
                            break
                        current_tokens.append(token)

                    if current_tokens:
                        results.append(mx.array(current_tokens))

            mx.metal.clear_cache()

        mx.eval(results)
        return results
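For orientation, here is a minimal sketch of how the new batch_size argument to generate_grpo might be exercised. It is illustrative only, not part of the commit: the leading positional parameters (model, prompt tensor, max_tokens, tokenizer) are inferred from the grpo_loss call site below, and model and tokenizer are assumed to be an MLX causal LM and its tokenizer loaded elsewhere.

import mlx.core as mx

# Hypothetical usage; `model` and `tokenizer` come from an MLX-LM checkpoint
# loaded elsewhere, and the argument order is inferred from the call site below.
prompts = mx.array([tokenizer.encode("What is 2 + 2?")])   # shape (1, prompt_len)
completions = generate_grpo(
    model,
    prompts,
    64,            # max_tokens
    tokenizer,
    4,             # group_size: samples drawn per prompt
    is_training=True,
    temperature=0.8,
    batch_size=2,  # decode the 4 samples in chunks of 2
)
# completions is a list of mx.array token sequences, or None for an empty prompt.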
@@ -151,24 +191,39 @@ def grpo_loss(
    ref_model,
    tokenizer,
    batch,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    beta: float = 0.1,
    group_size: int = 4,
    epsilon: float = 1e-4,
    max_tokens: int = 64,
    temperature: float = 0.8,
    reward_weights: Optional[List[float]] = None,
    is_validation: bool = False,
    batch_size: int = 1
):
    prompt_tokens, _, prompt_text, answer_text = batch
    total_samples = len(prompt_tokens)
    all_completions = []
    all_completion_texts = []
    batch_indices = []  # Keep track of which prompt each completion belongs to

    # Process the prompts in smaller batches
    for i in range(0, total_samples, batch_size):
        # Actual batch size for this iteration (may be smaller for the last batch)
        current_batch_size = min(batch_size, total_samples - i)
        batch_prompts = prompt_tokens[i:i + current_batch_size]

        # Pad the prompts to the same length
        max_prompt_len = max(len(p) for p in batch_prompts)
        padded_prompts = []
        for prompt in batch_prompts:
            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
            padded_prompts.append(prompt + padding)

        # Convert to a tensor
        prompt_tensor = mx.array(padded_prompts)

        try:
            if is_validation:
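A note on the padding step a few lines up: each prompt in the mini-batch is right-padded with tokenizer.pad_token_id up to the longest prompt before being stacked into a single tensor. A toy, self-contained sketch with made-up token ids:

import mlx.core as mx

# Hypothetical token ids; pad_id stands in for tokenizer.pad_token_id.
pad_id = 0
batch_prompts = [[5, 6, 7], [8, 9]]
max_prompt_len = max(len(p) for p in batch_prompts)
padded_prompts = [p + [pad_id] * (max_prompt_len - len(p)) for p in batch_prompts]
prompt_tensor = mx.array(padded_prompts)   # shape (2, 3): [[5, 6, 7], [8, 9, 0]]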
@@ -178,7 +233,8 @@ def grpo_loss(
                    max_tokens,
                    tokenizer,
                    group_size,
                    temperature=temperature,
                    batch_size=current_batch_size
                )
                model.train()
            else:
@@ -189,10 +245,16 @@ def grpo_loss(
                    tokenizer,
                    group_size,
                    is_training=True,
                    temperature=temperature,
                    batch_size=current_batch_size
                )

            if completions is not None:
                for j, completion_ids in enumerate(completions):
                    # Work out which prompt this completion belongs to
                    prompt_idx = i + (j // group_size)

                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
                        batch_indices.append(prompt_idx)
                        completion_text = tokenizer.decode(completion_ids.tolist())
                        all_completions.append(completion_ids)
                        all_completion_texts.append(completion_text)
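To make the index bookkeeping concrete, a toy example with made-up numbers: generate_grpo returns completions ordered prompt by prompt (each prompt repeated group_size times), so j // group_size recovers the prompt's offset within the mini-batch and i shifts it to the global prompt index.

# Illustration only, with hypothetical values.
group_size = 4
i = 2                                                 # first prompt index of this mini-batch
mapping = [i + (j // group_size) for j in range(8)]   # 2 prompts x 4 completions each
# mapping == [2, 2, 2, 2, 3, 3, 3, 3]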
@@ -203,12 +265,49 @@ def grpo_loss(
        mx.metal.clear_cache()

    # If no completions were generated, return early
    if not all_completions:
        print("No completions were generated. Returning zero loss.")
        dummy_loss = mx.zeros(1)
        dummy_metrics = {
            'total_rewards_mean': mx.zeros(1),
            'total_rewards_std': mx.zeros(1),
            'kl': mx.zeros(1)
        }
        return dummy_loss, mx.array(0), dummy_metrics

    # Create expanded prompts and answers based on the completions actually generated
    expanded_answers = []
    expanded_prompts = []

    # Group completions by their original prompt
    unique_prompt_indices = sorted(set(batch_indices))
    grouped_completions = {idx: [] for idx in unique_prompt_indices}

    for i, completion_idx in enumerate(batch_indices):
        grouped_completions[completion_idx].append(i)

    # Rebuild the completions in prompt order
    ordered_completions = []
    ordered_completion_texts = []
    ordered_batch_indices = []

    for prompt_idx in unique_prompt_indices:
        completion_indices = grouped_completions[prompt_idx]
        for idx in completion_indices:
            ordered_completions.append(all_completions[idx])
            ordered_completion_texts.append(all_completion_texts[idx])
            ordered_batch_indices.append(prompt_idx)

            # Add the corresponding prompt and answer
            expanded_prompts.append(prompt_text[prompt_idx])
            expanded_answers.append(answer_text[prompt_idx])

    all_completions = ordered_completions
    all_completion_texts = ordered_completion_texts
    batch_indices = ordered_batch_indices

    # Continue with the rest of the function
    max_length = max(ids.shape[0] for ids in all_completions)
    padded_completions = []
    attention_masks = []
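To illustrate the regrouping above with hypothetical values: the dictionary pass collects completion indices per prompt and re-emits them prompt by prompt, so completions, texts, prompts, and answers all end up aligned in the same group order.

# Illustration only, with made-up indices.
batch_indices = [0, 1, 0, 1]        # completion k was generated for prompt batch_indices[k]
unique_prompt_indices = sorted(set(batch_indices))      # [0, 1]
grouped = {idx: [] for idx in unique_prompt_indices}
for k, p in enumerate(batch_indices):
    grouped[p].append(k)
# grouped == {0: [0, 2], 1: [1, 3]}, so completions are re-emitted in order [0, 2, 1, 3]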
@@ -231,7 +330,6 @@ def grpo_loss(
    # Current policy probabilities
    token_log_probs = get_per_token_logps(model, inputs, lengths)
    mx.eval(token_log_probs)

    if ref_model is None:
@@ -282,11 +380,31 @@ def grpo_loss(
    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)

    # Get the number of unique prompts
    num_unique_prompts = len(unique_prompt_indices)

    # Group the rewards by the prompt they belong to
    rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
    for i, prompt_idx in enumerate(batch_indices):
        prompt_position = unique_prompt_indices.index(prompt_idx)
        rewards_by_prompt[prompt_position].append(rewards[i])

    # Calculate advantages within each group
    advantages = mx.zeros_like(rewards)
    for i, prompt_rewards in enumerate(rewards_by_prompt):
        if len(prompt_rewards) > 1:  # Only normalize if we have multiple samples
            prompt_rewards = mx.array(prompt_rewards)
            mean_reward = mx.mean(prompt_rewards)
            std_reward = mx.std(prompt_rewards)

            # Find the completion indices for this prompt
            indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]]
            for j, idx in enumerate(indices):
                advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon)
        else:
            # With a single sample, the advantage is 0
            idx = batch_indices.index(unique_prompt_indices[i])
            advantages[idx] = 0.0

    # Compute the KL divergence using Schulman's approximator
    kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
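Two short asides on the block above, as illustration rather than part of the commit. First, the per-group normalization standardizes each completion's reward against its own prompt's group of samples; a self-contained sketch with made-up rewards:

import mlx.core as mx

# Hypothetical rewards for the 4 completions of one prompt.
prompt_rewards = mx.array([1.0, 0.0, 0.5, 0.5])
epsilon = 1e-4
advantages = (prompt_rewards - mx.mean(prompt_rewards)) / (mx.std(prompt_rewards) + epsilon)
# Roughly [1.41, -1.41, 0.0, 0.0]: each completion is scored relative to its own group.

Second, the kl_div line is the k3 estimator from Schulman's note on approximating KL divergence: with r = log pi_ref(t) - log pi_theta(t) per token, KL(pi_theta || pi_ref) is approximated by exp(r) - r - 1, which is non-negative and unbiased when tokens are sampled from the current policy.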
@@ -320,23 +438,35 @@ def grpo_loss(
        reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
        reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)

    grouped_rewards_mean = mx.array([mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt])
    grouped_rewards_std = mx.array([mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) for rewards in rewards_by_prompt])

    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
        'grouped_rewards_std': mx.mean(grouped_rewards_std),
        'kl': mean_kl,
        **reward_metrics
    }

    if is_validation and all_completion_texts:
        print("\n=== Validation Sample Details ===")
        print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
        print("\n" + "=" * 10 + "\n")

        # Make sure we have a valid index into answer_text
        last_prompt_idx = batch_indices[-1] if batch_indices else 0
        if last_prompt_idx < len(answer_text):
            print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
            print("\n" + "=" * 10 + "\n")

        # Only try to extract if r1_extract_xml_answer is defined
        if 'r1_extract_xml_answer' in globals():
            print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
        print("\n" + "=" * 35 + "\n")

    mx.metal.clear_cache()

    return loss, sequence_lengths.sum(), metrics