Goekdeniz-Guelmez 2025-02-05 08:47:03 +01:00
parent 0a09a93454
commit 2a8e6f6e44


@@ -125,6 +125,11 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
    if prompt.shape[1] == 0:
        return None
    # Get "</answer>" token ids
    end_sequence = tokenizer.encode("</answer>")
    end_sequence_length = len(end_sequence)
    # Use int32 for token ids
    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
    output[:prompt.shape[1]] = prompt[0]
    current_length = prompt.shape[1]
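
For context, a minimal standalone sketch of what this hunk sets up, assuming a Hugging Face-style tokenizer; the model id below is only a placeholder, not the one used by the trainer:

import mlx.core as mx
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")  # placeholder model id

# Encode the closing tag once so the per-step check is a cheap id comparison.
# Depending on the tokenizer, encode() may prepend special tokens; passing
# add_special_tokens=False keeps only the literal "</answer>" ids.
end_sequence = tokenizer.encode("</answer>", add_special_tokens=False)
end_sequence_length = len(end_sequence)

# Pre-allocate one int32 buffer sized for prompt + completion, copy the
# prompt in, and track how much of the buffer is filled so far.
prompt = mx.array([tokenizer.encode("<answer>", add_special_tokens=False)])
max_tokens = 16
output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
output[:prompt.shape[1]] = prompt[0]
current_length = prompt.shape[1]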
@@ -146,9 +151,16 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
            output[current_length] = token_value
            current_length += 1
            # Check for EOS token
            if token_value == tokenizer.eos_token_id:
                break
            # Check for "</answer>" sequence
            if current_length >= end_sequence_length:
                last_tokens = output[current_length - end_sequence_length:current_length].tolist()
                if last_tokens == end_sequence:
                    break
        if current_length > prompt.shape[1]:
            result = output[:current_length]
            return result
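
The stopping rule added in this hunk can be exercised on its own; below is a small hedged sketch in plain Python (dummy token ids and a dummy sampler, not the trainer's model) that mirrors the EOS and "</answer>" tail checks:

EOS_ID = 2
END_SEQUENCE = [4, 5, 6]  # stands in for tokenizer.encode("</answer>")

def generate_until_stop(sample_next, prompt, max_tokens):
    # Mirrors the loop body above: append a token, then stop on EOS or when
    # the most recent ids match the encoded "</answer>" sequence.
    output = list(prompt)
    for _ in range(max_tokens):
        token = sample_next(output)
        output.append(token)
        if token == EOS_ID:
            break
        if len(output) >= len(END_SEQUENCE):
            if output[-len(END_SEQUENCE):] == END_SEQUENCE:
                break
    # As in the diff, only return a result if at least one token was generated.
    return output if len(output) > len(prompt) else None

# Dummy sampler that replays a fixed stream ending with the stop sequence.
stream = iter([9, 8, 4, 5, 6, 7])
print(generate_until_stop(lambda _ctx: next(stream), [1, 3], 10))
# -> [1, 3, 9, 8, 4, 5, 6]

Comparing token ids directly keeps the per-step check cheap: nothing is decoded back to text, and the only extra work each iteration is a short slice over the last few positions of the output buffer.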