Goekdeniz-Guelmez 2025-02-05 15:02:12 +01:00
parent 0a19522ec4
commit bcfa55d882


@@ -114,29 +114,25 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     prompt = prompt[None, :]
     if prompt.shape[1] == 0:
         return None
-    # Get "</answer>" token ids
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-    # Use int32 for token ids
     output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
     output[:prompt.shape[1]] = prompt[0]
     current_length = prompt.shape[1]
     try:
+        def sample(logits):
+            if temperature > 0:
+                logits /= temperature
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
         for _ in range(max_tokens):
             current_input = output[:current_length][None, :]
             logits = model(current_input)
             token_logits = logits[0, -1]
-            if temperature > 0:
-                token_logits /= temperature
-            probs = mx.softmax(token_logits)
-            next_token = mx.random.categorical(probs[None, :]).astype(mx.int32)
-            next_token = next_token[0]
+            next_token = sample(token_logits)
             token_value = next_token.item()
             output[current_length] = token_value
             current_length += 1
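
A note on what this first hunk changes: `mx.random.categorical` samples from the distribution defined by the softmax of its input, i.e. it expects (unnormalized) log-probabilities. The removed loop body passed `mx.softmax(token_logits)` into it, effectively applying softmax twice and flattening the distribution. The new `sample` helper instead passes a numerically stable log-softmax (`logits - mx.logsumexp(logits)`). A standalone sketch of the helper's behavior; the toy logits and the counting loop are illustrative only, not part of the commit:

```python
import mlx.core as mx

def sample(logits, temperature=1.0):
    # Temperature scaling: <1 sharpens the distribution, >1 flattens it.
    if temperature > 0:
        logits = logits / temperature
    # Log-softmax via logsumexp; mx.random.categorical treats its input as
    # unnormalized log-probabilities, so this is the right thing to pass.
    logprobs = logits - mx.logsumexp(logits, keepdims=True)
    return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]

# Toy logits: token 0 should dominate, more so at low temperature.
logits = mx.array([2.0, 0.5, 0.1])
counts = [0, 0, 0]
for _ in range(1000):
    counts[sample(logits, temperature=0.7).item()] += 1
print(counts)
```

One behavior carried over from the old code: when `temperature == 0` the helper skips scaling and still samples stochastically from the raw logits rather than falling back to a greedy `mx.argmax`.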
@@ -146,16 +142,19 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
             if current_length >= end_sequence_length:
                 last_tokens = output[current_length - end_sequence_length:current_length].tolist()
+                # print(f"Last tokens: {last_tokens}")
+                # print(f"Decoded text: {tokenizer.decode(last_tokens)}")
+                # print(f"Target sequence: {end_sequence}")
                 if last_tokens == end_sequence:
                     break
-        if current_length > prompt.shape[1]:
-            result = output[:current_length]
-            return result
+        return output[:current_length]
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
+    return None
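
The second hunk keeps the stop condition: once at least `end_sequence_length` tokens exist, the newest `end_sequence_length` ids are compared against `tokenizer.encode("</answer>")` and generation breaks on a match. It also drops the `current_length > prompt.shape[1]` guard, so the bare prompt is now returned even when no new token was generated. A minimal pure-Python sketch of the suffix test; the ids below are placeholders, not real tokenizer output:

```python
def ends_with_stop(token_ids, stop_ids):
    # Mirrors the loop's termination test: do the newest ids spell the stop sequence?
    n = len(stop_ids)
    return len(token_ids) >= n and token_ids[-n:] == stop_ids

stop_ids = [27, 9399, 29]  # placeholder for tokenizer.encode("</answer>")
print(ends_with_stop([5, 8, 27, 9399, 29], stop_ids))  # True
print(ends_with_stop([5, 8, 27, 9399], stop_ids))      # False
```

Comparing raw ids assumes the tokenizer always encodes `</answer>` to the same id sequence regardless of the preceding context; decoding the tail and comparing strings is the more robust variant when merges differ by context.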