separate functions

Goekdeniz-Guelmez 2025-03-05 15:28:12 +01:00
parent d723ddfeda
commit f13a0d04ca
2 changed files with 269 additions and 72 deletions


@@ -453,4 +453,177 @@ def main():
 if __name__ == "__main__":
     main()
+def compute_grpo_loss_and_grad(
+    model,
+    ref_model,
+    tokenizer,
+    completion_tensors,
+    prompt_texts,
+    answer_texts,
+    beta=0.1,
+    epsilon=1e-4,
+    reward_funcs=None,
+    reward_weights=None
+):
+    """
+    Compute GRPO loss and gradients using pre-generated completions.
+
+    Args:
+        model: The policy model
+        ref_model: The reference model
+        tokenizer: Tokenizer used to decode completions for reward computation
+        completion_tensors: List of tensors containing generated completions
+        prompt_texts: List of prompt texts
+        answer_texts: List of answer texts
+        beta: KL penalty coefficient
+        epsilon: Numerical stability constant
+        reward_funcs: List of reward functions
+        reward_weights: Optional weights for reward functions
+    """
+    # Ensure model is in training mode for gradient computation
+    model.train()
+
+    # Get completion texts for reward calculation
+    completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]
+
+    # Pad completions to a common length and build attention masks
+    max_length = max(tensor.shape[0] for tensor in completion_tensors)
+    padded_completions = []
+    attention_masks = []
+    for completion_tensor in completion_tensors:
+        padding_length = max_length - completion_tensor.shape[0]
+        if padding_length > 0:
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
+            mask = mx.concatenate(
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
+            )
+        else:
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
+        padded_completions.append(padded_ids)
+        attention_masks.append(mask)
+
+    inputs = mx.stack(padded_completions)
+    attention_mask = mx.stack(attention_masks)
+    lengths = attention_mask.sum(axis=1)
+
+    # Compute per-token log probabilities for both models; the reference
+    # values are detached so gradients only flow through the policy model
+    token_log_probs = get_per_token_logps(model, inputs, lengths)
+    if ref_model is None:
+        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
+    else:
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]
+
+    # Pad log probabilities to the same length
+    max_len = max(x.shape[0] for x in token_log_probs)
+    padded_log_probs = []
+    padded_ref_log_probs = []
+    for i in range(len(token_log_probs)):
+        seq_len = token_log_probs[i].shape[0]
+        padding = mx.zeros((max_len - seq_len,))
+        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
+
+    # Calculate rewards from each reward function
+    all_func_rewards = []
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(
+            reward_func(
+                prompts=prompt_texts,
+                completions=completion_texts,
+                answer=answer_texts,
+            )
+        )
+        all_func_rewards.append(func_rewards)
+
+    # Stack rewards and apply weights
+    rewards = mx.stack(all_func_rewards, axis=1)
+    if reward_weights is not None:
+        if len(reward_weights) != len(reward_funcs):
+            raise ValueError(
+                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
+                f"functions ({len(reward_funcs)})"
+            )
+        reward_weights = mx.array(reward_weights, dtype=mx.float32)
+    else:
+        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
+    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
+
+    # Group rewards by prompt (completions are assumed to be grouped by prompt)
+    group_size = len(completion_tensors) // len(prompt_texts)
+    if len(completion_tensors) % len(prompt_texts) != 0:
+        raise ValueError("Number of completions must be divisible by number of prompts")
+    rewards_by_group = []
+    for i in range(0, len(rewards), group_size):
+        rewards_by_group.append(rewards[i : i + group_size])
+
+    # Calculate group-relative advantages: each completion's reward is
+    # normalized by the mean and std of its prompt group
+    advantages = mx.zeros_like(rewards)
+    for i, group_rewards in enumerate(rewards_by_group):
+        if len(group_rewards) > 1:  # Only normalize if we have multiple samples
+            mean_reward = mx.mean(group_rewards)
+            std_reward = mx.std(group_rewards)
+            for j in range(group_size):
+                idx = i * group_size + j
+                advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
+        else:
+            # If only one sample, the advantage is 0
+            advantages[i * group_size] = 0.0
+
+    # Per-token KL divergence via the unbiased k3 estimator
+    kl_div = (
+        mx.exp(ref_token_log_probs - token_log_probs)
+        - (ref_token_log_probs - token_log_probs)
+        - 1
+    )
+
+    # Create mask for valid tokens (at most length - 1 predicted positions)
+    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
+
+    # Compute policy ratio against the reference model
+    policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)
+
+    # Per-token loss: reward the advantage-weighted ratio, penalize KL
+    per_token_loss = -(
+        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    )
+
+    # Average over tokens, then over sequences
+    sequence_sums = per_token_loss.sum(axis=1)
+    sequence_lengths = length_mask.sum(axis=1)
+    loss = (sequence_sums / sequence_lengths).mean()
+
+    # Calculate metrics for reporting
+    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
+    metrics = {
+        "total_rewards_mean": mx.mean(rewards),
+        "total_rewards_std": mx.std(rewards),
+        "kl": mean_kl,
+    }
+    for i, reward_func in enumerate(reward_funcs):
+        func_name = reward_func.__name__
+        func_rewards = all_func_rewards[i]
+        metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
+        metrics[f"{func_name}_std"] = mx.std(func_rewards)
+
+    return loss, sequence_lengths.sum(), metrics
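
For reference, the loss assembled above corresponds to the GRPO objective; the notation below is added here for clarity and is not part of the commit:

$$
\mathcal{L}(\theta) = -\frac{1}{N}\sum_{i=1}^{N}\frac{1}{|o_i|}\sum_{t=1}^{|o_i|}\left[\frac{\pi_\theta(o_{i,t})}{\pi_{\mathrm{ref}}(o_{i,t})}\,\hat{A}_i \;-\; \beta\,\hat{D}_{\mathrm{KL},i,t}\right],
\qquad
\hat{A}_i = \frac{r_i - \operatorname{mean}(r_{1:G})}{\operatorname{std}(r_{1:G}) + \epsilon},
$$

where the per-token KL penalty uses the unbiased k3 estimator, $\hat{D}_{\mathrm{KL}} = r - \log r - 1$ with $r = \pi_{\mathrm{ref}}/\pi_\theta$, exactly as in the `kl_div` expression above. As a quick sanity check of the advantage step: a group of four rewards [1.0, 0.0, 0.0, 1.0] has mean 0.5 and (population) std 0.5, so the advantages come out to roughly [+1, -1, -1, +1].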


@@ -297,6 +297,79 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
     return per_token_logps
+def generate_without_gradients(
+    model: nn.Module,
+    tokenizer,
+    prompt_tokens,
+    max_tokens: int,
+    group_size: int,
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
+    """Generate completions without tracking gradients."""
+    # Store original state
+    was_training = model.training
+
+    # Force eval mode
+    model.eval()
+
+    # Prepare prompts
+    total_samples = len(prompt_tokens)
+    all_completions = []
+    all_completion_texts = []
+    batch_indices = []
+
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i : i + current_batch_size]
+
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+
+        # Convert to tensor and explicitly stop gradient
+        prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                group_size,
+                temperature=temperature,
+                batch_size=current_batch_size,
+            )
+            if completions is not None:
+                for j, completion_ids in enumerate(completions):
+                    # Map each completion back to the prompt it was generated for
+                    prompt_idx = i + (j // group_size)
+                    if prompt_idx < total_samples:
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue
+
+    # Restore original state
+    if was_training:
+        model.train()
+    mx.metal.clear_cache()
+
+    return all_completions, all_completion_texts, batch_indices
 def grpo_loss(
     model,
     ref_model,
@@ -313,70 +386,17 @@ def grpo_loss(
     is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
-    total_samples = len(prompt_tokens)
-    all_completions = []
-    all_completion_texts = []
-    batch_indices = []  # Keep track of which batch each completion belongs to
-
-    # Store original training state
-    was_training = model.training
-    print(f"Was model in training mode: {was_training}")
-
-    # Set model to eval mode for generation
-    model.eval()
-    print(f"Is model now in training mode: {model.training}")
-
-    # Process in smaller batches
-    for i in range(0, total_samples, batch_size):
-        # Get actual batch size for this iteration (might be smaller for the last batch)
-        current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i : i + current_batch_size]
-
-        # Pad sequences to the same length
-        max_prompt_len = max(len(p) for p in batch_prompts)
-        padded_prompts = []
-        for prompt in batch_prompts:
-            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
-            padded_prompts.append(prompt + padding)
-
-        # Convert to tensor
-        prompt_tensor = mx.array(padded_prompts)
-        prompt_tensor = mx.stop_gradient(prompt_tensor)  # Explicitly stop gradient on input
-
-        try:
-            mx.metal.clear_cache()
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                temperature=temperature,
-                batch_size=current_batch_size,
-            )
-            if completions is not None:
-                for j, completion_ids in enumerate(completions):
-                    # Calculate which prompt this completion belongs to
-                    prompt_idx = i + (j // group_size)
-                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
-                        batch_indices.append(prompt_idx)
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-        except Exception as e:
-            print(f"Generation error: {e}")
-            print(f"Is model in training mode after generation: {model.training}")
-            continue
-
-    # Restore original training state if we're not in validation mode
-    if was_training:
-        model.train()
-    mx.metal.clear_cache()
+    # Generate completions without tracking gradients
+    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
+        model=model,
+        tokenizer=tokenizer,
+        prompt_tokens=prompt_tokens,
+        max_tokens=max_tokens,
+        group_size=group_size,
+        temperature=temperature,
+        batch_size=batch_size
+    )
     # If we didn't generate any completions, return early
     if not all_completions:
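
Taken together, the refactor splits GRPO into a gradient-free generation phase and a differentiable scoring phase. A minimal sketch of how the two new helpers compose (hypothetical wiring, not code from this commit; `my_reward_func` and the surrounding variables are assumed to exist, and reordering the completions by `batch_indices` is omitted for brevity):

import mlx.core as mx
import mlx.nn as nn

prompt_tokens, _, prompt_texts, answer_texts = batch

# Phase 1: sample completions with the policy frozen in eval mode
completions, completion_texts, batch_indices = generate_without_gradients(
    model=model,
    tokenizer=tokenizer,
    prompt_tokens=prompt_tokens,
    max_tokens=max_tokens,
    group_size=group_size,
)

# Phase 2: score the fixed completions and differentiate only the loss
def loss_fn(model):
    loss, _, _ = compute_grpo_loss_and_grad(
        model, ref_model, tokenizer, completions,
        prompt_texts, answer_texts,
        reward_funcs=[my_reward_func],  # hypothetical reward function
    )
    return loss

loss, grads = nn.value_and_grad(model, loss_fn)(model)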
@@ -415,25 +435,30 @@ def grpo_loss(
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
-    # Continue with the rest of the function
+    # Create new input tensors for the model to compute logits with gradient tracking
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
     for completion_ids in all_completions:
-        padding_length = max_length - completion_ids.shape[0]
+        # Convert the pre-generated completion to a regular tensor (not stop_gradient)
+        # This allows gradients to flow during the loss computation phase
+        completion_tensor = mx.array(completion_ids.tolist())
+        padding_length = max_length - completion_tensor.shape[0]
         if padding_length > 0:
-            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
-            padded_ids = mx.concatenate([completion_ids, padding])
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
             mask = mx.concatenate(
-                [mx.ones_like(completion_ids), mx.zeros_like(padding)]
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
             )
         else:
-            padded_ids = completion_ids
-            mask = mx.ones_like(completion_ids)
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
-    # Rest of the function remains the same
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
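
The tolist()/mx.array round-trip above rebuilds each completion as a fresh leaf array with the same token ids but no connection to the computation graph that generated it, so only the new forward pass inside the loss is traced. A standalone illustration of the idea (my sketch, not part of the commit):

import mlx.core as mx

y = mx.array([1, 2, 3]) * 2      # an array produced by some computation
fresh = mx.array(y.tolist())     # same values, new leaf: no graph history
frozen = mx.stop_gradient(y)     # alternative: keep the array, block gradients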
@@ -721,7 +746,6 @@ def evaluate_grpo(
         ref_model=ref_model,
         temperature=temperature,
         max_tokens=max_tokens,
-        is_validation=True
     )
     all_losses += losses * toks