EOS token return fix

This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 14:00:51 +01:00
parent 2bde97fe13
commit 1f89453295

View File

@ -225,6 +225,11 @@ def generate_grpo(
) )
): ):
print(token) print(token)
# Check for EOS token
if token == tokenizer.eos_token_id:
break
current_tokens.append(token) current_tokens.append(token)
# Check for end token # Check for end token
@ -233,10 +238,6 @@ def generate_grpo(
): ):
break break
# Check for EOS token
if token == tokenizer.eos_token_id:
break
# Check if we've reached the maximum number of tokens # Check if we've reached the maximum number of tokens
if i >= max_tokens - 1: if i >= max_tokens - 1:
break break
@ -305,7 +306,6 @@ def grpo_loss(
max_tokens: int = 64, max_tokens: int = 64,
temperature: float = 0.8, temperature: float = 0.8,
reward_weights: Optional[List[float]] = None, reward_weights: Optional[List[float]] = None,
is_validation: bool = False,
batch_size: int = 1, batch_size: int = 1,
): ):
prompt_tokens, _, prompt_text, answer_text = batch prompt_tokens, _, prompt_text, answer_text = batch
@ -366,7 +366,7 @@ def grpo_loss(
continue continue
# Restore original training state if we're not in validation mode # Restore original training state if we're not in validation mode
if not is_validation and was_training: if was_training:
model.train() model.train()
mx.metal.clear_cache() mx.metal.clear_cache()
@ -712,8 +712,7 @@ def evaluate_grpo(
epsilon=epsilon, epsilon=epsilon,
ref_model=ref_model, ref_model=ref_model,
temperature=temperature, temperature=temperature,
max_tokens=max_tokens, max_tokens=max_tokens
is_validation=True,
) )
all_losses += losses * toks all_losses += losses * toks