This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 14:25:55 +01:00
parent 1f89453295
commit 2d2f39f96e

View File

@@ -175,6 +175,8 @@ def generate_grpo(
try: try:
import time import time
model.freeze()
start_time = time.time() start_time = time.time()
if len(prompts.shape) == 1: if len(prompts.shape) == 1:
@@ -229,8 +231,9 @@ def generate_grpo(
# Check for EOS token # Check for EOS token
if token == tokenizer.eos_token_id: if token == tokenizer.eos_token_id:
break break
current_tokens.append(token) current_tokens.append(token)
mx.eval(current_tokens[-1])
# Check for end token # Check for end token
if len(current_tokens) >= len(end_sequence) and mx.array_equal( if len(current_tokens) >= len(end_sequence) and mx.array_equal(
@@ -268,7 +271,9 @@ def generate_grpo(
) )
print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec") print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
results = [mx.stop_gradient(r) for r in results]
mx.eval(results) mx.eval(results)
model.unfreeze()
return results return results
except Exception as e: except Exception as e:
@@ -307,6 +312,7 @@ def grpo_loss(
temperature: float = 0.8, temperature: float = 0.8,
reward_weights: Optional[List[float]] = None, reward_weights: Optional[List[float]] = None,
batch_size: int = 1, batch_size: int = 1,
is_validation: bool = False
): ):
prompt_tokens, _, prompt_text, answer_text = batch prompt_tokens, _, prompt_text, answer_text = batch
total_samples = len(prompt_tokens) total_samples = len(prompt_tokens)
@@ -321,6 +327,7 @@ def grpo_loss(
# Set model to eval mode for generation # Set model to eval mode for generation
model.eval() model.eval()
print(f"Is model now in training mode: {model.training}")
# Process in smaller batches # Process in smaller batches
for i in range(0, total_samples, batch_size): for i in range(0, total_samples, batch_size):
@@ -340,6 +347,7 @@ def grpo_loss(
prompt_tensor = mx.array(padded_prompts) prompt_tensor = mx.array(padded_prompts)
try: try:
mx.metal.clear_cache()
completions = generate_grpo( completions = generate_grpo(
model, model,
prompt_tensor, prompt_tensor,
@@ -363,6 +371,7 @@ def grpo_loss(
mx.eval(completion_ids) mx.eval(completion_ids)
except Exception as e: except Exception as e:
print(f"Generation error: {e}") print(f"Generation error: {e}")
print(f"Is model in training mode after generation: {model.training}")
continue continue
# Restore original training state if we're not in validation mode # Restore original training state if we're not in validation mode
@@ -712,7 +721,8 @@ def evaluate_grpo(
epsilon=epsilon, epsilon=epsilon,
ref_model=ref_model, ref_model=ref_model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens max_tokens=max_tokens,
is_validation=True
) )
all_losses += losses * toks all_losses += losses * toks