diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index a9ba4b01..f995a05c 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -114,29 +114,25 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature): prompt = prompt[None, :] if prompt.shape[1] == 0: return None - - # Get "" token ids + end_sequence = tokenizer.encode("") end_sequence_length = len(end_sequence) - - # Use int32 for token ids output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32) output[:prompt.shape[1]] = prompt[0] current_length = prompt.shape[1] - + try: + def sample(logits): + if temperature > 0: + logits /= temperature + logprobs = logits - mx.logsumexp(logits, keepdims=True) + return mx.random.categorical(logprobs[None, :]).astype(mx.int32)[0] + for _ in range(max_tokens): current_input = output[:current_length][None, :] logits = model(current_input) token_logits = logits[0, -1] - - if temperature > 0: - token_logits /= temperature - - probs = mx.softmax(token_logits) - next_token = mx.random.categorical(probs[None, :]).astype(mx.int32) - next_token = next_token[0] - + next_token = sample(token_logits) token_value = next_token.item() output[current_length] = token_value current_length += 1 @@ -146,16 +142,19 @@ def generate_grpo(model, prompt, max_tokens, tokenizer, temperature): if current_length >= end_sequence_length: last_tokens = output[current_length - end_sequence_length:current_length].tolist() + # print(f"Last tokens: {last_tokens}") + # print(f"Decoded text: {tokenizer.decode(last_tokens)}") + # print(f"Target sequence: {end_sequence}") if last_tokens == end_sequence: break - + if current_length > prompt.shape[1]: - result = output[:current_length] - return result + return output[:current_length] except Exception as e: print(f"Generation error: {str(e)}") return None + return None