diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 3b5dc2e9..5661085e 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -36,41 +36,6 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
-    if len(prompt.shape) == 1:
-        prompt = prompt[None, :]
-
-    generated = []
-    current_prompt = prompt[0]
-
-    for _ in range(max_tokens):
-        current_batch = current_prompt[None, :]
-        logits = model(current_batch)
-        token_logits = logits[0, -1]
-
-        if temperature > 0:
-            token_logits = token_logits / temperature
-
-        probs = mx.softmax(token_logits)
-        next_token = mx.random.categorical(probs[None, :])
-        next_token = next_token[0]
-        mx.eval(next_token)
-
-        token_value = next_token.item()
-        generated.append(next_token)
-
-        current_prompt = mx.concatenate([current_prompt, next_token[None]])
-        if token_value == tokenizer.eos_token_id:
-            break
-
-    if not generated:
-        return prompt[0]
-
-    result = mx.concatenate([prompt[0], mx.stack(generated)])
-    mx.eval(result)
-    return result
-
-
 def r1_extract_xml_answer(text: str) -> str:
     """Extracts the answer from an XML formatted text string."""
     try:
@@ -154,9 +119,48 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list:
     return scores
 
 
+def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
+    if len(prompt.shape) == 1:
+        prompt = prompt[None, :]
+    if prompt.shape[1] == 0:
+        return None
+
+    output = mx.zeros((prompt.shape[1] + max_tokens,), dtype=mx.int32)
+    output[:prompt.shape[1]] = prompt[0]
+    current_length = prompt.shape[1]
+
+    try:
+        for _ in range(max_tokens):
+            current_input = output[:current_length][None, :]
+            logits = model(current_input)
+            token_logits = logits[0, -1]
+
+            if temperature > 0:
+                token_logits /= temperature
+
+            probs = mx.softmax(token_logits)
+            next_token = mx.random.categorical(probs[None, :]).astype(mx.int32)
+            next_token = next_token[0]
+
+            token_value = next_token.item()
+            output[current_length] = token_value
+            current_length += 1
+
+            if token_value == tokenizer.eos_token_id:
+                break
+
+        if current_length > prompt.shape[1]:
+            result = output[:current_length]
+            return result
+
+    except Exception as e:
+        print(f"Generation error: {str(e)}")
+        return None
+    return None
+
+
 def get_per_token_logps(model, inputs, lengths):
-    # Get logits from model
-    logits = model(inputs).astype(mx.float32)  # [batch_size, seq_len, vocab_size]
+    logits = model(inputs).astype(mx.float16)  # [batch_size, seq_len, vocab_size]
     # Remove last position as it corresponds to the next token prediction
     logits = logits[:, :-1, :]  # [batch_size, seq_len-1, vocab_size]
     targets = inputs[:, 1:]  # Shift inputs to get targets
@@ -182,6 +186,7 @@
         ).squeeze(-1)  # [seq_len]
 
         per_token_logps.append(token_log_probs)
+        mx.eval(logits)
     return per_token_logps
 
 
@@ -204,22 +209,26 @@
     all_completions = []
     all_completion_texts = []
 
-    for prompt in prompt_tokens:
-        prompt_tensor = mx.array(prompt)
-
-        for _ in range(group_size):
-            try:
-                completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
-                if completion_ids is None:
+    for i in range(0, batch_size, batch_size):
+        batch_prompts = prompt_tokens[i:i+batch_size]
+        for prompt in batch_prompts:
+            prompt_tensor = mx.array(prompt)
+            for _ in range(group_size):
+                try:
+                    completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature)
+                    if completion_ids is not None:
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+
+                        # Clear completion tensors
+                        mx.eval(completion_ids)
+                        del completion_ids
+                except Exception as e:
+                    print(f"Generation error: {e}")
                     continue
-                completion_text = tokenizer.decode(completion_ids.tolist())
-                all_completions.append(completion_ids)
-                all_completion_texts.append(completion_text)
-
-            except Exception as e:
-                print(f"Generation error: {e}")
-                continue
+        mx.metal.clear_cache()
 
 
     # Prepare inputs
     expanded_answers = []
@@ -250,6 +259,10 @@
 
     # Current policy probabilities
    token_log_probs = get_per_token_logps(model, inputs, lengths)
+
+    mx.eval(token_log_probs)
+    mx.metal.clear_cache()
+
     # Reference policy probabilities
     if ref_model is not None:
         ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
@@ -263,7 +276,7 @@
 
     for i in range(len(token_log_probs)):
         seq_len = token_log_probs[i].shape[0]
-        padding = mx.zeros((max_len - seq_len,), dtype=mx.float32)
+        padding = mx.zeros((max_len - seq_len,), dtype=mx.float16)
         padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
         padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
 
@@ -330,6 +343,7 @@
         'kl': mean_kl,
         **reward_metrics
     }
+    mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics
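
For reference, a minimal usage sketch of the rewritten generate_grpo (not part of the patch). The mlx_lm.load helper and the model name are illustrative assumptions; after this patch, generate_grpo is importable from mlx_lm.tuner.grpo_trainer.

# Usage sketch (illustrative, not part of the patch). Assumes mlx_lm's
# top-level `load` helper and an arbitrary small instruct model.
import mlx.core as mx
from mlx_lm import load
from mlx_lm.tuner.grpo_trainer import generate_grpo

model, tokenizer = load("Qwen/Qwen2.5-0.5B-Instruct")
prompt = mx.array(tokenizer.encode("What is 2 + 2?"))

# One GRPO group: sample several completions for the same prompt, as
# grpo_loss does with group_size. A None result signals an empty prompt
# or a generation failure inside generate_grpo.
completions = []
for _ in range(4):
    ids = generate_grpo(model, prompt, max_tokens=64, tokenizer=tokenizer, temperature=1.0)
    if ids is not None:
        completions.append(tokenizer.decode(ids.tolist()))
mx.metal.clear_cache()  # mirror the patch's cache release between prompts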