separate functions

Author: Goekdeniz-Guelmez
Date: 2025-03-05 15:28:12 +01:00
Parent commit: d723ddfeda
Commit: f13a0d04ca
2 changed files with 269 additions and 72 deletions


@@ -297,6 +297,79 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
     return per_token_logps
+def generate_without_gradients(
+    model: nn.Module,
+    tokenizer,
+    prompt_tokens,
+    max_tokens: int,
+    group_size: int,
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
+    """Generate completions without tracking gradients"""
+    # Store original state
+    was_training = model.training
+    # Force eval mode
+    model.eval()
+    # Prepare prompts
+    total_samples = len(prompt_tokens)
+    all_completions = []
+    all_completion_texts = []
+    batch_indices = []
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i : i + current_batch_size]
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+        # Convert to tensor and explicitly stop gradient
+        prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                group_size,
+                temperature=temperature,
+                batch_size=current_batch_size,
+            )
+            if completions is not None:
+                for j, completion_ids in enumerate(completions):
+                    prompt_idx = i + (j // group_size)
+                    if prompt_idx < total_samples:
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue
+    # Restore original state
+    if was_training:
+        model.train()
+    mx.metal.clear_cache()
+    return all_completions, all_completion_texts, batch_indices
 def grpo_loss(
     model,
     ref_model,
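For reference, here is a minimal usage sketch of the extracted helper, outside of grpo_loss. It assumes generate_without_gradients is importable from this trainer module (the import path and the sample_group wrapper below are placeholders, not part of the commit), that prompt_tokens is a list of token-id lists, and that the tokenizer exposes encode, decode, and pad_token_id:

# Hypothetical import path; adjust to wherever this trainer module lives.
from grpo_trainer import generate_without_gradients

def sample_group(model, tokenizer, prompts, group_size=4, max_tokens=128):
    # Encode plain-text prompts into token-id lists (encode/pad_token_id assumed).
    prompt_tokens = [tokenizer.encode(p) for p in prompts]
    completions, completion_texts, batch_indices = generate_without_gradients(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,
        max_tokens=max_tokens,
        group_size=group_size,
        temperature=0.8,
        batch_size=1,
    )
    # batch_indices[k] says which prompt the k-th completion came from.
    for idx, text in zip(batch_indices, completion_texts):
        print(f"prompt {idx}: {text[:80]}")
    return completions, completion_texts, batch_indices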
@@ -313,70 +386,17 @@ def grpo_loss(
     is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
-    total_samples = len(prompt_tokens)
-    all_completions = []
-    all_completion_texts = []
-    batch_indices = []  # Keep track of which batch each completion belongs to
-    # Store original training state
-    was_training = model.training
-    print(f"Was model in training mode: {was_training}")
-    # Set model to eval mode for generation
-    model.eval()
-    print(f"Is model now in training mode: {model.training}")
-    # Process in smaller batches
-    for i in range(0, total_samples, batch_size):
-        # Get actual batch size for this iteration (might be smaller for the last batch)
-        current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i : i + current_batch_size]
-        # Pad sequences to the same length
-        max_prompt_len = max(len(p) for p in batch_prompts)
-        padded_prompts = []
-        for prompt in batch_prompts:
-            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
-            padded_prompts.append(prompt + padding)
-        # Convert to tensor
-        prompt_tensor = mx.array(padded_prompts)
-        prompt_tensor = mx.stop_gradient(prompt_tensor)  # Explicitly stop gradient on input
-        try:
-            mx.metal.clear_cache()
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                temperature=temperature,
-                batch_size=current_batch_size,
-            )
-            if completions is not None:
-                for j, completion_ids in enumerate(completions):
-                    # Calculate which prompt this completion belongs to
-                    prompt_idx = i + (j // group_size)
-                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
-                        batch_indices.append(prompt_idx)
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-        except Exception as e:
-            print(f"Generation error: {e}")
-            print(f"Is model in training mode after generation: {model.training}")
-            continue
-    # Restore original training state if we're not in validation mode
-    if was_training:
-        model.train()
-    mx.metal.clear_cache()
+    # Generate completions without tracking gradients
+    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
+        model=model,
+        tokenizer=tokenizer,
+        prompt_tokens=prompt_tokens,
+        max_tokens=max_tokens,
+        group_size=group_size,
+        temperature=temperature,
+        batch_size=batch_size
+    )
     # If we didn't generate any completions, return early
     if not all_completions:
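Both the removed inline loop and the new helper map each completion back to its source prompt with prompt_idx = i + (j // group_size). A small self-contained check of that arithmetic, with illustrative sizes that are not taken from the commit:

# generate_grpo is assumed to return group_size completions per prompt, in
# prompt order, for each mini-batch starting at offset i.
batch_size = 2
group_size = 3
total_samples = 5  # the last mini-batch holds a single prompt

mapping = []
for i in range(0, total_samples, batch_size):
    current_batch_size = min(batch_size, total_samples - i)
    for j in range(current_batch_size * group_size):
        mapping.append(i + (j // group_size))

# Every prompt index appears exactly group_size times:
# [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
print(mapping)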
@@ -415,25 +435,30 @@ def grpo_loss(
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
     # Continue with the rest of the function
+    # Create new input tensors for the model to compute logits with gradient tracking
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
     for completion_ids in all_completions:
-        padding_length = max_length - completion_ids.shape[0]
+        # Convert the pre-generated completion to a regular tensor (not stop_gradient)
+        # This allows gradients to flow during the loss computation phase
+        completion_tensor = mx.array(completion_ids.tolist())
+        padding_length = max_length - completion_tensor.shape[0]
         if padding_length > 0:
-            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
-            padded_ids = mx.concatenate([completion_ids, padding])
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
             mask = mx.concatenate(
-                [mx.ones_like(completion_ids), mx.zeros_like(padding)]
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
             )
         else:
-            padded_ids = completion_ids
-            mask = mx.ones_like(completion_ids)
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
+    # Rest of the function remains the same
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
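The padding and masking step above can also be tried in isolation. A small sketch with toy token ids standing in for real completions (the values are made up), using only the MLX operations that appear in the diff:

import mlx.core as mx

# Toy "completions" of unequal length.
all_completions = [mx.array([5, 6, 7, 8]), mx.array([9, 10])]

max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
attention_masks = []
for completion_ids in all_completions:
    # Re-materialize as a fresh tensor so the later forward pass is not cut
    # off from gradients by a stop_gradient applied during generation.
    completion_tensor = mx.array(completion_ids.tolist())
    padding_length = max_length - completion_tensor.shape[0]
    if padding_length > 0:
        padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
        padded_ids = mx.concatenate([completion_tensor, padding])
        mask = mx.concatenate([mx.ones_like(completion_tensor), mx.zeros_like(padding)])
    else:
        padded_ids = completion_tensor
        mask = mx.ones_like(completion_tensor)
    padded_completions.append(padded_ids)
    attention_masks.append(mask)

inputs = mx.stack(padded_completions)       # shape (2, 4)
attention_mask = mx.stack(attention_masks)  # shape (2, 4)
lengths = attention_mask.sum(axis=1)        # [4, 2]
print(inputs, lengths)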
@@ -721,7 +746,6 @@ def evaluate_grpo(
             ref_model=ref_model,
             temperature=temperature,
             max_tokens=max_tokens,
-            is_validation=True
         )
         all_losses += losses * toks