mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-12 20:26:45 +08:00
eos token return fix
This commit is contained in:
parent
2bde97fe13
commit
1f89453295
@ -225,6 +225,11 @@ def generate_grpo(
|
|||||||
)
|
)
|
||||||
):
|
):
|
||||||
print(token)
|
print(token)
|
||||||
|
|
||||||
|
# Check for EOS token
|
||||||
|
if token == tokenizer.eos_token_id:
|
||||||
|
break
|
||||||
|
|
||||||
current_tokens.append(token)
|
current_tokens.append(token)
|
||||||
|
|
||||||
# Check for end token
|
# Check for end token
|
||||||
@ -233,10 +238,6 @@ def generate_grpo(
|
|||||||
):
|
):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Check for EOS token
|
|
||||||
if token == tokenizer.eos_token_id:
|
|
||||||
break
|
|
||||||
|
|
||||||
# Check if we've reached the maximum number of tokens
|
# Check if we've reached the maximum number of tokens
|
||||||
if i >= max_tokens - 1:
|
if i >= max_tokens - 1:
|
||||||
break
|
break
|
||||||
@ -305,7 +306,6 @@ def grpo_loss(
|
|||||||
max_tokens: int = 64,
|
max_tokens: int = 64,
|
||||||
temperature: float = 0.8,
|
temperature: float = 0.8,
|
||||||
reward_weights: Optional[List[float]] = None,
|
reward_weights: Optional[List[float]] = None,
|
||||||
is_validation: bool = False,
|
|
||||||
batch_size: int = 1,
|
batch_size: int = 1,
|
||||||
):
|
):
|
||||||
prompt_tokens, _, prompt_text, answer_text = batch
|
prompt_tokens, _, prompt_text, answer_text = batch
|
||||||
@ -366,7 +366,7 @@ def grpo_loss(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Restore original training state if we're not in validation mode
|
# Restore original training state if we're not in validation mode
|
||||||
if not is_validation and was_training:
|
if was_training:
|
||||||
model.train()
|
model.train()
|
||||||
mx.metal.clear_cache()
|
mx.metal.clear_cache()
|
||||||
|
|
||||||
@ -712,8 +712,7 @@ def evaluate_grpo(
|
|||||||
epsilon=epsilon,
|
epsilon=epsilon,
|
||||||
ref_model=ref_model,
|
ref_model=ref_model,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens
|
||||||
is_validation=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
all_losses += losses * toks
|
all_losses += losses * toks
|
||||||
|
Loading…
Reference in New Issue
Block a user