diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 2e53fb0d..f7734eb6 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -453,4 +453,177 @@ def main():
 
 
 if __name__ == "__main__":
-    main()
\ No newline at end of file
+    main()
+
+
+
+
+
+
+
+def compute_grpo_loss_and_grad(
+    model,
+    ref_model,
+    completion_tensors,
+    prompt_texts,
+    answer_texts,
+    beta=0.1,
+    epsilon=1e-4,
+    reward_funcs=None,
+    reward_weights=None
+):
+    """
+    Compute GRPO loss and gradients using pre-generated completions.
+
+    Args:
+        model: The policy model
+        ref_model: The reference model
+        completion_tensors: List of tensors containing generated completions
+        prompt_texts: List of prompt texts
+        answer_texts: List of answer texts
+        beta: KL penalty coefficient
+        epsilon: Numerical stability constant
+        reward_funcs: List of reward functions
+        reward_weights: Optional weights for reward functions
+    """
+    # Ensure model is in training mode for gradient computation
+    model.train()
+
+    # Get completion texts for reward calculation
+    completion_texts = [tokenizer.decode(comp.tolist()) for comp in completion_tensors]
+
+    # Prepare inputs for loss computation
+    max_length = max(tensor.shape[0] for tensor in completion_tensors)
+    padded_completions = []
+    attention_masks = []
+
+    for completion_tensor in completion_tensors:
+        padding_length = max_length - completion_tensor.shape[0]
+        if padding_length > 0:
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
+            mask = mx.concatenate(
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
+            )
+        else:
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
+        padded_completions.append(padded_ids)
+        attention_masks.append(mask)
+
+    inputs = mx.stack(padded_completions)
+    attention_mask = mx.stack(attention_masks)
+    lengths = attention_mask.sum(axis=1)
+
+    # Compute log probabilities for both models
+    token_log_probs = get_per_token_logps(model, inputs, lengths)
+
+    if ref_model is None:
+        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in token_log_probs]
+    else:
+        ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
+        ref_token_log_probs = [mx.stop_gradient(tlp) for tlp in ref_token_log_probs]
+
+    # Pad log probabilities to same length
+    max_len = max(x.shape[0] for x in token_log_probs)
+    padded_log_probs = []
+    padded_ref_log_probs = []
+
+    for i in range(len(token_log_probs)):
+        seq_len = token_log_probs[i].shape[0]
+        padding = mx.zeros((max_len - seq_len,))
+
+        padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
+        padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
+
+    token_log_probs = mx.stack(padded_log_probs)
+    ref_token_log_probs = mx.stack(padded_ref_log_probs)
+
+    # Calculate rewards
+    all_func_rewards = []
+    for reward_func in reward_funcs:
+        func_rewards = mx.array(
+            reward_func(
+                prompts=prompt_texts,
+                completions=completion_texts,
+                answer=answer_texts,
+            )
+        )
+        all_func_rewards.append(func_rewards)
+
+    # Stack rewards and apply weights
+    rewards = mx.stack(all_func_rewards, axis=1)
+    if reward_weights is not None:
+        if len(reward_weights) != len(reward_funcs):
+            raise ValueError(
+                f"Number of reward weights ({len(reward_weights)}) must match number of reward "
+                f"functions ({len(reward_funcs)})"
+            )
+        reward_weights = mx.array(reward_weights, dtype=mx.float32)
+    else:
+        reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
+
+    rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
+
+    # Group rewards by prompt (assuming completions are grouped by prompt)
+    group_size = len(completion_tensors) // len(prompt_texts)
+    if len(completion_tensors) % len(prompt_texts) != 0:
+        raise ValueError("Number of completions must be divisible by number of prompts")
+
+    rewards_by_group = []
+    for i in range(0, len(rewards), group_size):
+        rewards_by_group.append(rewards[i:i+group_size])
+
+    # Calculate advantages
+    advantages = mx.zeros_like(rewards)
+    for i, group_rewards in enumerate(rewards_by_group):
+        if len(group_rewards) > 1:  # Only normalize if we have multiple samples
+            mean_reward = mx.mean(group_rewards)
+            std_reward = mx.std(group_rewards)
+
+            for j in range(group_size):
+                idx = i * group_size + j
+                advantages[idx] = (group_rewards[j] - mean_reward) / (std_reward + epsilon)
+        else:
+            # If only one sample, advantage is 0
+            advantages[i * group_size] = 0.0
+
+    # Compute KL divergence
+    kl_div = (
+        mx.exp(ref_token_log_probs - token_log_probs)
+        - (ref_token_log_probs - token_log_probs)
+        - 1
+    )
+
+    # Create mask for valid tokens
+    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
+
+    # Compute policy ratio
+    policy_ratio = mx.exp(token_log_probs - ref_token_log_probs)
+
+    # Compute per-token loss
+    per_token_loss = -(
+        (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
+    )
+
+    # Average over tokens
+    sequence_sums = per_token_loss.sum(axis=1)
+    sequence_lengths = length_mask.sum(axis=1)
+    loss = (sequence_sums / sequence_lengths).mean()
+
+    # Calculate metrics for reporting
+    mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
+
+    metrics = {
+        "total_rewards_mean": mx.mean(rewards),
+        "total_rewards_std": mx.std(rewards),
+        "kl": mean_kl,
+    }
+
+    for i, reward_func in enumerate(reward_funcs):
+        func_name = reward_func.__name__
+        func_rewards = all_func_rewards[i]
+        metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
+        metrics[f"{func_name}_std"] = mx.std(func_rewards)
+
+    return loss, sequence_lengths.sum(), metrics
\ No newline at end of file
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index feb27737..ec080bca 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -297,6 +297,79 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
     return per_token_logps
 
 
+def generate_without_gradients(
+    model: nn.Module,
+    tokenizer,
+    prompt_tokens,
+    max_tokens: int,
+    group_size: int,
+    temperature: float = 0.8,
+    batch_size: int = 1
+):
+    """Generate completions without tracking gradients"""
+
+    # Store original state
+    was_training = model.training
+
+    # Force eval mode
+    model.eval()
+
+    # Prepare prompts
+    total_samples = len(prompt_tokens)
+    all_completions = []
+    all_completion_texts = []
+    batch_indices = []
+
+    # Process in smaller batches
+    for i in range(0, total_samples, batch_size):
+        current_batch_size = min(batch_size, total_samples - i)
+        batch_prompts = prompt_tokens[i : i + current_batch_size]
+
+        # Pad sequences to the same length
+        max_prompt_len = max(len(p) for p in batch_prompts)
+        padded_prompts = []
+
+        for prompt in batch_prompts:
+            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+            padded_prompts.append(prompt + padding)
+
+        # Convert to tensor and explicitly stop gradient
+        prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
+
+        try:
+            completions = generate_grpo(
+                model,
+                prompt_tensor,
+                max_tokens,
+                tokenizer,
+                group_size,
+                temperature=temperature,
+                batch_size=current_batch_size,
+            )
+
+            if completions is not None:
+                for j, completion_ids in enumerate(completions):
+                    prompt_idx = i + (j // group_size)
+
+                    if prompt_idx < total_samples:
+                        batch_indices.append(prompt_idx)
+                        completion_text = tokenizer.decode(completion_ids.tolist())
+                        all_completions.append(completion_ids)
+                        all_completion_texts.append(completion_text)
+                        mx.eval(completion_ids)
+        except Exception as e:
+            print(f"Generation error: {e}")
+            continue
+
+    # Restore original state
+    if was_training:
+        model.train()
+
+    mx.metal.clear_cache()
+
+    return all_completions, all_completion_texts, batch_indices
+
+
 def grpo_loss(
     model,
     ref_model,
@@ -313,70 +386,17 @@ def grpo_loss(
     is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
 
-    total_samples = len(prompt_tokens)
-    all_completions = []
-    all_completion_texts = []
-    batch_indices = []  # Keep track of which batch each completion belongs to
-
-    # Store original training state
-    was_training = model.training
-    print(f"Was model in training mode: {was_training}")
-
-    # Set model to eval mode for generation
-    model.eval()
-    print(f"Is model now in training mode: {model.training}")
-
-    # Process in smaller batches
-    for i in range(0, total_samples, batch_size):
-        # Get actual batch size for this iteration (might be smaller for the last batch)
-        current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i : i + current_batch_size]
-
-        # Pad sequences to the same length
-        max_prompt_len = max(len(p) for p in batch_prompts)
-        padded_prompts = []
-
-        for prompt in batch_prompts:
-            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
-            padded_prompts.append(prompt + padding)
-
-        # Convert to tensor
-        prompt_tensor = mx.array(padded_prompts)
-        prompt_tensor = mx.stop_gradient(prompt_tensor)  # Explicitly stop gradient on input
-
-        try:
-            mx.metal.clear_cache()
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                temperature=temperature,
-                batch_size=current_batch_size,
-            )
-
-            if completions is not None:
-                for j, completion_ids in enumerate(completions):
-                    # Calculate which prompt this completion belongs to
-                    prompt_idx = i + (j // group_size)
-
-                    if prompt_idx < total_samples:  # Make sure we don't go out of bounds
-                        batch_indices.append(prompt_idx)
-                        completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
-                        all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-        except Exception as e:
-            print(f"Generation error: {e}")
-            print(f"Is model in training mode after generation: {model.training}")
-            continue
-
-    # Restore original training state if we're not in validation mode
-    if was_training:
-        model.train()
-    mx.metal.clear_cache()
+    # Generate completions without tracking gradients
+    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
+        model=model,
+        tokenizer=tokenizer,
+        prompt_tokens=prompt_tokens,
+        max_tokens=max_tokens,
+        group_size=group_size,
+        temperature=temperature,
+        batch_size=batch_size
+    )
 
     # If we didn't generate any completions, return early
     if not all_completions:
@@ -415,25 +435,30 @@ def grpo_loss(
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
 
-    # Continue with the rest of the function
+    # Create new input tensors for the model to compute logits with gradient tracking
     max_length = max(ids.shape[0] for ids in all_completions)
     padded_completions = []
     attention_masks = []
 
     for completion_ids in all_completions:
-        padding_length = max_length - completion_ids.shape[0]
+        # Convert the pre-generated completion to a regular tensor (not stop_gradient)
+        # This allows gradients to flow during the loss computation phase
+        completion_tensor = mx.array(completion_ids.tolist())
+
+        padding_length = max_length - completion_tensor.shape[0]
         if padding_length > 0:
-            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
-            padded_ids = mx.concatenate([completion_ids, padding])
+            padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
+            padded_ids = mx.concatenate([completion_tensor, padding])
             mask = mx.concatenate(
-                [mx.ones_like(completion_ids), mx.zeros_like(padding)]
+                [mx.ones_like(completion_tensor), mx.zeros_like(padding)]
             )
         else:
-            padded_ids = completion_ids
-            mask = mx.ones_like(completion_ids)
+            padded_ids = completion_tensor
+            mask = mx.ones_like(completion_tensor)
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
 
+    # Rest of the function remains the same
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
@@ -721,7 +746,6 @@ def evaluate_grpo(
             ref_model=ref_model,
            temperature=temperature,
             max_tokens=max_tokens,
-            is_validation=True
        )
         all_losses += losses * toks
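
Usage note (illustrative, not part of the patch): the two helpers split a GRPO step into a generation phase that runs with gradients disabled and a loss phase that re-scores the sampled completions with gradient tracking. The sketch below is one way they might be wired into a single training step with MLX's nn.value_and_grad. The step function name, the optimizer, the reward functions, and the batch layout are assumptions for illustration only, and note that compute_grpo_loss_and_grad as pasted reads tokenizer, mx, and get_per_token_logps from module scope, so those names must be available wherever it lives.

import mlx.core as mx
import mlx.nn as nn


def grpo_train_step(model, ref_model, tokenizer, optimizer, batch, reward_funcs,
                    group_size=4, max_tokens=256, temperature=0.8):
    # Hypothetical driver: generate_without_gradients and
    # compute_grpo_loss_and_grad are the helpers introduced in this diff and
    # are assumed to be in scope here. Error handling and regrouping of
    # completions by batch_indices are omitted for brevity.
    prompt_tokens, _, prompt_texts, answer_texts = batch

    # Phase 1: sample completions with the policy in eval mode, no gradients.
    completions, _completion_texts, _batch_indices = generate_without_gradients(
        model=model,
        tokenizer=tokenizer,
        prompt_tokens=prompt_tokens,
        max_tokens=max_tokens,
        group_size=group_size,
        temperature=temperature,
        batch_size=1,
    )
    if not completions:
        return None

    # Phase 2: score the frozen completions under the current policy and the
    # reference model; gradients flow only through the policy.
    def loss_fn(policy):
        return compute_grpo_loss_and_grad(
            policy,
            ref_model,
            completion_tensors=completions,
            prompt_texts=prompt_texts,
            answer_texts=answer_texts,
            reward_funcs=reward_funcs,
        )

    loss_and_grad = nn.value_and_grad(model, loss_fn)
    (loss, num_tokens, metrics), grads = loss_and_grad(model)

    # Phase 3: apply the update and force evaluation of the lazy graph.
    optimizer.update(model, grads)
    mx.eval(model.parameters(), optimizer.state, loss)
    return loss, num_tokens, metrics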