eos token return fix

This commit is contained in:
Goekdeniz-Guelmez 2025-03-05 14:00:51 +01:00
parent 2bde97fe13
commit 1f89453295

View File

@ -225,6 +225,11 @@ def generate_grpo(
)
):
print(token)
# Check for EOS token
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
# Check for end token
@ -233,10 +238,6 @@ def generate_grpo(
):
break
# Check for EOS token
if token == tokenizer.eos_token_id:
break
# Check if we've reached the maximum number of tokens
if i >= max_tokens - 1:
break
@ -305,7 +306,6 @@ def grpo_loss(
max_tokens: int = 64,
temperature: float = 0.8,
reward_weights: Optional[List[float]] = None,
is_validation: bool = False,
batch_size: int = 1,
):
prompt_tokens, _, prompt_text, answer_text = batch
@ -366,7 +366,7 @@ def grpo_loss(
continue
# Restore original training state if we're not in validation mode
if not is_validation and was_training:
if was_training:
model.train()
mx.metal.clear_cache()
@ -712,8 +712,7 @@ def evaluate_grpo(
epsilon=epsilon,
ref_model=ref_model,
temperature=temperature,
max_tokens=max_tokens,
is_validation=True,
max_tokens=max_tokens
)
all_losses += losses * toks