From 2a8e6f6e4461af787b7980b30e3d5f38ba9ed390 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Feb 2025 08:47:03 +0100
Subject: [PATCH] update

---
 llms/mlx_lm/tuner/grpo_trainer.py | 12 ++++++++++++
 1 file changed, 12 insertions(+)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 5661085e..0210b44a 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -125,6 +125,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
     if prompt.shape[1] == 0:
         return None
 
+    # Get "</answer>" token ids
+    end_sequence = tokenizer.encode("</answer>")
+    end_sequence_length = len(end_sequence)
+
+    # Use int32 for token ids
     output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
     output[:prompt.shape[1]] = prompt[0]
     current_length = prompt.shape[1]
@@ -146,9 +151,16 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
         output[current_length] = token_value
         current_length += 1
 
+        # Check for EOS token
         if token_value == tokenizer.eos_token_id:
             break
 
+        # Check for the "</answer>" sequence
+        if current_length >= end_sequence_length:
+            last_tokens = output[current_length - end_sequence_length:current_length].tolist()
+            if last_tokens == end_sequence:
+                break
+
     if current_length > prompt.shape[1]:
         result = output[:current_length]
         return result
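
Note for reviewers: the patch stops generation once the last `end_sequence_length` tokens match the encoded "</answer>" string. Below is a minimal, self-contained sketch of that tail-comparison logic. `DummyTokenizer` and `ends_with_sequence` are illustrative names, not part of the patch; the trainer itself slices the `mx.int32` output buffer and calls `.tolist()` to do the same comparison.

    # Sketch of the stop-sequence check, under the assumption that the
    # tokenizer produces one id per character (real tokenizers do not,
    # which is why the patch encodes "</answer>" once up front rather
    # than decoding and string-matching at every step).
    class DummyTokenizer:
        eos_token_id = 0

        def encode(self, text):
            return [ord(c) for c in text]

    def ends_with_sequence(tokens, end_sequence):
        """Return True once the tail of `tokens` equals `end_sequence`."""
        n = len(end_sequence)
        return len(tokens) >= n and tokens[-n:] == end_sequence

    tokenizer = DummyTokenizer()
    end_sequence = tokenizer.encode("</answer>")

    generated = tokenizer.encode("<answer>42</answer>")
    assert ends_with_sequence(generated, end_sequence)  # generation stops here

Comparing only the last `end_sequence_length` tokens keeps the per-step cost proportional to the stop-sequence length, independent of how many tokens have been generated so far.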