small fix

Goekdeniz-Guelmez 2025-02-11 17:48:42 +01:00
parent 35ecc17042
commit 978deab589


@@ -117,7 +117,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
     end_sequence = tokenizer.encode("</answer>")
     end_sequence_length = len(end_sequence)
-    output = mx.zeros((prompt.shape[1] + max_tokens,))
+    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
     output[:prompt.shape[1]] = prompt[0]
     current_length = prompt.shape[1]
@@ -126,7 +126,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature):
         if temperature > 0:
            logits /= temperature
        logprobs = logits - mx.logsumexp(logits, keepdims=True)
-       return mx.random.categorical(logprobs[None, :])[0]
+       return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
     for _ in range(max_tokens):
        current_input = output[:current_length][None, :]
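
Both hunks address the same issue: the dtype of the token buffer and of sampled tokens. A minimal sketch of why this matters, assuming standard MLX semantics (mx.zeros defaults to float32, and mx.random.categorical returns an unsigned-integer sample); the variable names below are illustrative and not part of the commit:

import mlx.core as mx

# mx.zeros defaults to float32, which is the wrong dtype for a token-ID
# buffer: IDs must stay exact integers to be fed back as model input.
old_buf = mx.zeros((8,))                  # float32 (pre-fix behavior)
new_buf = mx.zeros((8,), dtype=mx.int32)  # int32 (the fix)
print(old_buf.dtype, new_buf.dtype)       # mlx.core.float32 mlx.core.int32

# mx.random.categorical samples an integer index from log-probabilities;
# casting with .astype(mx.int32) keeps the sampled token's dtype
# consistent with the int32 output buffer it is written into.
logits = mx.array([0.2, 1.5, 0.3])
logprobs = logits - mx.logsumexp(logits, keepdims=True)
token = mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0]
print(token.dtype)                        # mlx.core.int32

Without the explicit dtype, assigning the prompt and sampled tokens into a float32 buffer would silently store them as floats, so the explicit int32 on both sides keeps the generation loop type-consistent end to end.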