From 978deab589cda54ad750df8c28b833ef36354731 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Tue, 11 Feb 2025 17:48:42 +0100 Subject: [PATCH] fix generate_grpo dtypes: allocate int32 output buffer and cast sampled token ids to int32 --- llms/mlx_lm/tuner/grpo_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index da42da2b..4a1e6bbf 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -117,7 +117,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature): end_sequence = tokenizer.encode("") end_sequence_length = len(end_sequence) - output = mx.zeros((prompt.shape[1] + max_tokens,)) + output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32) output[:prompt.shape[1]] = prompt[0] current_length = prompt.shape[1] @@ -126,7 +126,7 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature): if temperature > 0: logits /= temperature logprobs = logits - mx.logsumexp(logits, keepdims=True) - return mx.random.categorical(logprobs[None, :])[0] + return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0] for _ in range(max_tokens): current_input = output[:current_length][None, :]