diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 66abf99f..351ba9de 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -51,74 +51,114 @@ class GRPOTrainingArgs(TrainingArgs):
 def generate_grpo(
-    model: nn.Module,
-    prompts,
-    max_tokens,
-    tokenizer,
-    group_size,
-    is_training=False,
-    end_token: str = "",
-    temperature: float = 0.8
-    ):
+    model: nn.Module,
+    prompts,
+    max_tokens,
+    tokenizer,
+    group_size,
+    is_training=False,
+    end_token: str = "",
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
     if len(prompts.shape) == 1:
         prompts = prompts[None, :]
     if prompts.shape[1] == 0:
         return None
-    batch_size = prompts.shape[0] * group_size
+
+    total_samples = prompts.shape[0] * group_size
     expanded_prompts = mx.repeat(prompts, group_size, axis=0)
     end_sequence = mx.array(tokenizer.encode(end_token))
     results = []
     mx.eval(expanded_prompts)
+
     try:
-        for idx in range(batch_size):
-            current_tokens = []
+        # Process the expanded prompts in chunks of batch_size
+        for batch_start in range(0, total_samples, batch_size):
+            batch_end = min(batch_start + batch_size, total_samples)
+
             if is_training:
-                current_input = expanded_prompts[idx]
-                prompt_cache = cache.make_prompt_cache(model)
-                logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                mx.eval(logits, prompt_cache)
-                while len(current_tokens) < max_tokens:
-                    logits_temp = logits / temperature
-                    probs = nn.softmax(logits_temp, axis=-1)
-                    next_token = mx.random.categorical(logits_temp)
-                    token = next_token.item()
-                    test_sequence = current_tokens + [token]
-                    if (len(test_sequence) >= len(end_sequence) and
-                        mx.array_equal(
-                            mx.array(test_sequence[-len(end_sequence):]),
-                            end_sequence
-                        )):
-                        current_tokens.append(token)
-                        break
-                    if token == tokenizer.eos_token_id:
-                        break
-                    current_tokens.append(token)
-                    current_input = mx.array([token])
-                    logits = model(current_input[None], cache=prompt_cache)[:, -1]
-                    mx.eval(current_input, logits, probs, next_token, token)
-            else:
-                generator = generate_step(
-                    expanded_prompts[idx],
-                    model,
-                    max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x / temperature)
-                )
-                for token, _ in generator:
-                    test_sequence = current_tokens + [token]
-                    if (len(test_sequence) >= len(end_sequence) and
-                        mx.array_equal(
-                            mx.array(test_sequence[-len(end_sequence):]),
-                            end_sequence
-                        )):
-                        current_tokens.append(token)
-                        break
+                # Training mode with batched processing
+                batch_inputs = expanded_prompts[batch_start:batch_end]
+                prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
+
+                # Initial forward pass for all prompts in the batch
+                batch_logits = []
+                for i, prompt in enumerate(batch_inputs):
+                    logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
+                    batch_logits.append(logits)
+                mx.eval(batch_logits, prompt_caches)
+
+                # Track tokens for each sequence in the batch
+                batch_tokens = [[] for _ in range(batch_end - batch_start)]
+                active_indices = list(range(batch_end - batch_start))
+
+                # Generate until all sequences are complete; a sequence leaves
+                # active_indices once it emits the end sequence or EOS, or
+                # reaches max_tokens, so the loop always terminates
+                while active_indices:
+                    next_active = []
+                    for idx in active_indices:
+                        logits_temp = batch_logits[idx] / temperature
+                        next_token = mx.random.categorical(logits_temp)
+                        token = next_token.item()
-                    if token == tokenizer.eos_token_id:
-                        break
-                    current_tokens.append(token)
-            if current_tokens:
-                results.append(mx.array(current_tokens))
+                        test_sequence = batch_tokens[idx] + [token]
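+                        # Does the sequence now end with the encoded end_token?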
+                        is_end = (len(test_sequence) >= len(end_sequence) and
+                            mx.array_equal(
+                                mx.array(test_sequence[-len(end_sequence):]),
+                                end_sequence
+                            ))
+
+                        batch_tokens[idx].append(token)
+
+                        # Keep only unfinished sequences active; each one gets
+                        # another single-token forward pass through its cache
+                        if not (is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens):
+                            next_active.append(idx)
+                            current_input = mx.array([token])
+                            batch_logits[idx] = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+
+                    mx.eval([batch_logits[idx] for idx in next_active])
+                    active_indices = next_active
+
+                # Add batch results to the overall results
+                for tokens in batch_tokens:
+                    if tokens:
+                        results.append(mx.array(tokens))
+
+            else:
+                # Non-training mode: generate each sequence in the chunk with generate_step
+                for idx in range(batch_start, batch_end):
+                    current_tokens = []
+                    generator = generate_step(
+                        expanded_prompts[idx],
+                        model,
+                        max_tokens=max_tokens,
+                        sampler=lambda x: mx.random.categorical(x / temperature)
+                    )
+
+                    for token, _ in generator:
+                        test_sequence = current_tokens + [token]
+                        if (len(test_sequence) >= len(end_sequence) and
+                            mx.array_equal(
+                                mx.array(test_sequence[-len(end_sequence):]),
+                                end_sequence
+                            )):
+                            current_tokens.append(token)
+                            break
+
+                        if token == tokenizer.eos_token_id:
+                            break
+                        current_tokens.append(token)
+
+                    if current_tokens:
+                        results.append(mx.array(current_tokens))
+
     finally:
         mx.metal.clear_cache()
+    mx.eval(results)
     return results
 
@@ -151,24 +191,39 @@ def grpo_loss(
     ref_model,
     tokenizer,
     batch,
-    reward_funcs=None,
-    beta=0.1,
-    group_size=4,
-    epsilon=1e-4,
-    max_tokens=64,
-    temperature=1.0,
-    reward_weights=None,
-    is_validation=False
+    reward_funcs: Optional[List[RewardFunctions]] = None,
+    beta: float = 0.1,
+    group_size: int = 4,
+    epsilon: float = 1e-4,
+    max_tokens: int = 64,
+    temperature: float = 0.8,
+    reward_weights: Optional[List[float]] = None,
+    is_validation: bool = False,
+    batch_size: int = 1
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
-    batch_size = len(prompt_tokens)
+    total_samples = len(prompt_tokens)
     all_completions = []
     all_completion_texts = []
+    batch_indices = []  # Keep track of which prompt each completion belongs to
 
-    for i in range(0, batch_size, batch_size):
-        batch_prompts = prompt_tokens[i:i+batch_size]
-        prompt_tensor = mx.stack([mx.array(p) for p in batch_prompts])
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        # Get the actual batch size for this iteration (the last batch may be smaller)
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i:i+current_batch_size]
+
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+
+        # Convert to tensor
+        prompt_tensor = mx.array(padded_prompts)
 
         try:
             if is_validation:
@@ -178,7 +233,8 @@
                     max_tokens,
                     tokenizer,
                     group_size,
-                    temperature=temperature
+                    temperature=temperature,
+                    batch_size=current_batch_size
                 )
                 model.train()
             else:
@@ -189,26 +245,69 @@
                     tokenizer,
                     group_size,
                     is_training=True,
-                    temperature=temperature
+                    temperature=temperature,
+                    batch_size=current_batch_size
                 )
+
             if completions is not None:
-                for completion_ids in completions:
-                    completion_text = tokenizer.decode(completion_ids.tolist())
-                    all_completions.append(completion_ids)
-                    all_completion_texts.append(completion_text)
-                    mx.eval(completion_ids)
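+                # generate_grpo returns group_size completions per prompt, in prompt order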
+                for j, completion_ids in enumerate(completions):
+                    # Calculate which prompt this completion belongs to
+                    prompt_idx = i + (j // group_size)
+                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
         except Exception as e:
             print(f"Generation error: {e}")
             continue
-
+        mx.metal.clear_cache()
 
+    # If we didn't generate any completions, return early
+    if not all_completions:
+        print("No completions were generated. Returning zero loss.")
+        dummy_loss = mx.zeros(1)
+        dummy_metrics = {
+            'total_rewards_mean': mx.zeros(1),
+            'total_rewards_std': mx.zeros(1),
+            'kl': mx.zeros(1)
+        }
+        return dummy_loss, mx.array(0), dummy_metrics
+
+    # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []
-    for i in range(batch_size):
-        expanded_answers.extend([answer_text[i]] * group_size)
-        expanded_prompts.extend([prompt_text[i]] * group_size)
-
+
+    # Group completions by their original prompt
+    unique_prompt_indices = sorted(set(batch_indices))
+    grouped_completions = {idx: [] for idx in unique_prompt_indices}
+
+    for i, completion_idx in enumerate(batch_indices):
+        grouped_completions[completion_idx].append(i)
+
+    # Rebuild completions in the correct order
+    ordered_completions = []
+    ordered_completion_texts = []
+    ordered_batch_indices = []
+
+    for prompt_idx in unique_prompt_indices:
+        completion_indices = grouped_completions[prompt_idx]
+        for idx in completion_indices:
+            ordered_completions.append(all_completions[idx])
+            ordered_completion_texts.append(all_completion_texts[idx])
+            ordered_batch_indices.append(prompt_idx)
+
+            # Add corresponding prompt and answer
+            expanded_prompts.append(prompt_text[prompt_idx])
+            expanded_answers.append(answer_text[prompt_idx])
+
+    all_completions = ordered_completions
+    all_completion_texts = ordered_completion_texts
+    batch_indices = ordered_batch_indices
+
+    # Continue with the rest of the function
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
@@ -231,7 +330,6 @@
 
     # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)
-
     mx.eval(token_log_probs)
 
     if ref_model is None:
@@ -282,11 +380,31 @@
         rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
 
-    # Reshape rewards and compute advantages
-    rewards_reshaped = rewards.reshape(batch_size, group_size)
-    mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
-    std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1)
-    advantages = (rewards - mean_rewards) / (std_rewards + epsilon)
+    # Get number of unique prompts
+    num_unique_prompts = len(unique_prompt_indices)
+
+    # Reshape rewards based on actual groups
+    rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
+    for i, prompt_idx in enumerate(batch_indices):
+        prompt_position = unique_prompt_indices.index(prompt_idx)
+        rewards_by_prompt[prompt_position].append(rewards[i])
+
+    # Calculate advantages within each group
+    advantages = mx.zeros_like(rewards)
+    for i, prompt_rewards in enumerate(rewards_by_prompt):
+        if len(prompt_rewards) > 1:  # Only normalize if we have multiple samples
+            prompt_rewards = mx.array(prompt_rewards)
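+            # Standardize within the group; epsilon guards against zero variance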
+            mean_reward = mx.mean(prompt_rewards)
+            std_reward = mx.std(prompt_rewards)
+
+            # Find indices for this prompt
+            indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]]
+            for j, idx in enumerate(indices):
+                advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon)
+        else:
+            # If only one sample, the advantage is 0
+            idx = batch_indices.index(unique_prompt_indices[i])
+            advantages[idx] = 0.0
 
     # Compute KL divergence using Schulman's approximator
     kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1
@@ -319,24 +437,36 @@
         ))
         reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards)
         reward_metrics[f'{func_name}_std'] = mx.std(func_rewards)
+
+    grouped_rewards_mean = mx.array([mx.mean(mx.array(r)) for r in rewards_by_prompt])
+    grouped_rewards_std = mx.array([mx.std(mx.array(r)) if len(r) > 1 else mx.array(0.0) for r in rewards_by_prompt])
 
     metrics = {
         'total_rewards_mean': mx.mean(rewards),
         'total_rewards_std': mx.std(rewards),
-        'grouped_rewards_mean': mx.mean(rewards_reshaped),
-        'grouped_rewards_std': mx.std(rewards_reshaped),
+        'grouped_rewards_mean': mx.mean(grouped_rewards_mean),
+        'grouped_rewards_std': mx.mean(grouped_rewards_std),
         'kl': mean_kl,
         **reward_metrics
     }
 
-    if is_validation:
+    if is_validation and all_completion_texts:
         print("\n=== Validation Sample Details ===")
         print(f"\nšŸ“ Generation:\n{all_completion_texts[-1]}")
         print("\n" + "="*10 + "\n")
-        print(f"\nāœ… Answer:\n{answer_text[-1]}")
-        print("\n" + "="*10 + "\n")
-        print(f"\nšŸ” Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
-        print("\n" + "="*30 + "\n")
+
+        # Make sure we have a valid index for answer_text
+        last_prompt_idx = batch_indices[-1] if batch_indices else 0
+        if last_prompt_idx < len(answer_text):
+            print(f"\nāœ… Answer:\n{answer_text[last_prompt_idx]}")
+            print("\n" + "="*10 + "\n")
+
+        # Only try to extract if r1_extract_xml_answer is defined
+        if 'r1_extract_xml_answer' in globals():
+            print(f"\nšŸ” Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
+            print("\n" + "="*35 + "\n")
+
+    mx.metal.clear_cache()
 
     return loss, sequence_lengths.sum(), metrics
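
Note on the advantage computation: the patch standardizes each completion's reward against the other completions generated for the same prompt, rather than reshaping into a fixed (batch_size, group_size) grid. A standalone sketch of that calculation (illustrative only, not part of the patch; toy rewards, two prompts with two completions each):

    import mlx.core as mx

    rewards = mx.array([1.0, 0.0, 0.5, 2.0])  # one reward per completion
    batch_indices = [0, 0, 1, 1]              # completion -> prompt mapping
    epsilon = 1e-4                            # same role as the grpo_loss argument

    advantages = mx.zeros_like(rewards)
    for prompt_idx in sorted(set(batch_indices)):
        idxs = [i for i, p in enumerate(batch_indices) if p == prompt_idx]
        group = mx.array([rewards[i] for i in idxs])
        mean, std = mx.mean(group), mx.std(group)
        # Standardize each reward within its own group
        for j, i in enumerate(idxs):
            advantages[i] = (group[j] - mean) / (std + epsilon)

    print(advantages)  # ~[1.0, -1.0, -1.0, 1.0]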