mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-12 04:06:42 +08:00
eos token return fix
This commit is contained in:
parent
2bde97fe13
commit
1f89453295
@ -225,6 +225,11 @@ def generate_grpo(
|
||||
)
|
||||
):
|
||||
print(token)
|
||||
|
||||
# Check for EOS token
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
current_tokens.append(token)
|
||||
|
||||
# Check for end token
|
||||
@ -233,10 +238,6 @@ def generate_grpo(
|
||||
):
|
||||
break
|
||||
|
||||
# Check for EOS token
|
||||
if token == tokenizer.eos_token_id:
|
||||
break
|
||||
|
||||
# Check if we've reached the maximum number of tokens
|
||||
if i >= max_tokens - 1:
|
||||
break
|
||||
@ -305,7 +306,6 @@ def grpo_loss(
|
||||
max_tokens: int = 64,
|
||||
temperature: float = 0.8,
|
||||
reward_weights: Optional[List[float]] = None,
|
||||
is_validation: bool = False,
|
||||
batch_size: int = 1,
|
||||
):
|
||||
prompt_tokens, _, prompt_text, answer_text = batch
|
||||
@ -366,7 +366,7 @@ def grpo_loss(
|
||||
continue
|
||||
|
||||
# Restore original training state if we're not in validation mode
|
||||
if not is_validation and was_training:
|
||||
if was_training:
|
||||
model.train()
|
||||
mx.metal.clear_cache()
|
||||
|
||||
@ -712,8 +712,7 @@ def evaluate_grpo(
|
||||
epsilon=epsilon,
|
||||
ref_model=ref_model,
|
||||
temperature=temperature,
|
||||
max_tokens=max_tokens,
|
||||
is_validation=True,
|
||||
max_tokens=max_tokens
|
||||
)
|
||||
|
||||
all_losses += losses * toks
|
||||
|
Loading…
Reference in New Issue
Block a user