From 05d921b788f0b269a0f2a35c17499f626e021599 Mon Sep 17 00:00:00 2001 From: Goekdeniz-Guelmez Date: Mon, 3 Feb 2025 19:37:05 +0100 Subject: [PATCH] optims --- llms/mlx_lm/tuner/grpo_trainer.py | 324 +++++++++++++----------------- 1 file changed, 138 insertions(+), 186 deletions(-) diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 16125611..f6dfc830 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -35,68 +35,50 @@ class GRPOTrainingArgs(TrainingArgs): ) -def generate_for_grpo( - model, - prompt, - max_tokens, - tokenizer, - temperature=1.0 -): - try: +def generate_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0): + model.eval() + if len(prompt.shape) == 1: + prompt = prompt[None, :] + + generated = [] + current_prompt = prompt[0] + + for _ in range(max_tokens): + current_batch = current_prompt[None, :] + logits = model(current_batch) + token_logits = logits[0, -1] - # Ensure prompt is the right shape - if len(prompt.shape) == 1: - prompt = prompt[None, :] - - # Initialize generation - generated = [] - current_prompt = prompt[0] - - for step in range(max_tokens): - try: - # Get model output with explicit shape checking - current_batch = current_prompt[None, :] - - logits = model(current_batch) - - # Ensure we have the last token logits - token_logits = logits[0, -1] - - # Apply temperature and get probabilities - if temperature > 0: - token_logits = token_logits / temperature - probs = mx.softmax(token_logits) - - # Sample the next token - next_token = mx.random.categorical(probs[None, :]) - next_token = next_token[0] - - # Force evaluation to catch any issues - mx.eval(next_token) - token_value = next_token.item() - - # Add to generated sequence - generated.append(next_token) - current_prompt = mx.concatenate([current_prompt, next_token[None]]) - - if token_value == tokenizer.eos_token_id: - break - - except Exception as e: - raise - - if not generated: - return prompt[0] + if temperature > 0: + token_logits = token_logits / temperature - try: - result = mx.concatenate([prompt[0], mx.stack(generated)]) - mx.eval(result) - return result - except Exception as e: - raise + probs = mx.softmax(token_logits) + next_token = mx.random.categorical(probs[None, :]) + next_token = next_token[0] + mx.eval(next_token) + + token_value = next_token.item() + generated.append(next_token) + + # Clear intermediate tensors + del logits, token_logits, probs + mx.metal.clear_cache() + + current_prompt = mx.concatenate([current_prompt, next_token[None]]) + if token_value == tokenizer.eos_token_id: + break - except Exception as e: - raise + if not generated: + return prompt[0] + + result = mx.concatenate([prompt[0], mx.stack(generated)]) + mx.eval(result) + model.train() + + # Clear generated tokens + del generated + mx.metal.clear_cache() + + return result def r1_extract_xml_answer(text: str) -> str: @@ -191,67 +173,46 @@ def grpo_loss( group_size=4, epsilon=1e-4, ref_model=None, - max_tokens=128, + max_tokens=64, temperature=1.0 ): - """Modified GRPO loss function with better error handling""" prompt_tokens, answer_tokens, prompt_text, answer_text = batch batch_size = len(prompt_tokens) - # Generate completions for each prompt + # Generation logic remains the same all_completions = [] all_completion_texts = [] for prompt in prompt_tokens: prompt_tensor = mx.array(prompt) - prompt_completions = [] - prompt_completion_texts = [] - # Generate group_size completions for each prompt for _ in range(group_size): try: - 
completion_ids = generate_for_grpo( - model, - prompt_tensor, - max_tokens, - tokenizer=tokenizer, - temperature=temperature - ) - - # Verify completion_ids is not None + completion_ids = generate_grpo(model, prompt_tensor, max_tokens, tokenizer, temperature) if completion_ids is None: - print("Warning: generate_for_grpo returned None") - break + continue completion_text = tokenizer.decode(completion_ids.tolist()) + all_completions.append(completion_ids) + all_completion_texts.append(completion_text) - prompt_completions.append(completion_ids) - prompt_completion_texts.append(completion_text) + del completion_ids + mx.metal.clear_cache() except Exception as e: - print(f"Error in completion generation: {str(e)}") - # Fallback to using original prompt - prompt_completions.append(prompt_tensor) - prompt_completion_texts.append(tokenizer.decode(prompt_tensor.tolist())) + print(f"Generation error: {e}") + continue - all_completions.extend(prompt_completions) - all_completion_texts.extend(prompt_completion_texts) + del prompt_tensor + mx.metal.clear_cache() - # Verify we have the expected number of completions - assert len(all_completions) == batch_size * group_size - assert len(all_completion_texts) == batch_size * group_size - - # Expand answer_text and prompt_text to match completion groups + # Prepare inputs expanded_answers = [] expanded_prompts = [] for i in range(batch_size): expanded_answers.extend([answer_text[i]] * group_size) expanded_prompts.extend([prompt_text[i]] * group_size) - - # Verify we have the expected number of completions - assert len(all_completions) == batch_size * group_size - assert len(all_completion_texts) == batch_size * group_size - + max_length = max(ids.shape[0] for ids in all_completions) padded_completions = [] attention_masks = [] @@ -267,32 +228,37 @@ def grpo_loss( mask = mx.ones_like(completion_ids) padded_completions.append(padded_ids) attention_masks.append(mask) + + del completion_ids + if padding_length > 0: + del padding + del mask + mx.metal.clear_cache() inputs = mx.stack(padded_completions) attention_mask = mx.stack(attention_masks) lengths = attention_mask.sum(axis=1) - # Get logits from current model + del padded_completions, attention_masks + mx.metal.clear_cache() + + # Get logits and compute log probabilities logits = model(inputs).astype(mx.float32) - - # Calculate log probabilities log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1) - - # Prepare targets targets = inputs[:, 1:] - # Gather actual token probabilities + # Current policy probabilities token_log_probs = mx.take_along_axis( log_probs, targets.reshape(*targets.shape, 1), axis=-1 ).squeeze(-1) - # Get reference model log probabilities + # Reference policy probabilities if ref_model is not None: ref_logits = ref_model(inputs).astype(mx.float32) else: - ref_logits = model(inputs).astype(mx.float32) + ref_logits = mx.array(logits) ref_log_probs = nn.log_softmax(ref_logits[:, :-1, :], axis=-1) ref_token_log_probs = mx.take_along_axis( @@ -301,124 +267,107 @@ def grpo_loss( axis=-1 ).squeeze(-1) - # Compute KL divergence - kl_div = (mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1) - - # Calculate combined rewards from all reward functions + # Calculate rewards and advantages rewards = mx.zeros((len(all_completions),)) for reward_func in reward_funcs: func_rewards = mx.array(reward_func( - prompts=prompt_text, - completions=all_completion_texts, - answer=answer_text - )) - rewards += func_rewards - - # Normalize rewards if using multiple 
reward functions - if len(reward_funcs) > 1: - rewards /= len(reward_funcs) - - # Compute grouped-wise rewards - grouped_rewards = rewards.reshape(batch_size, group_size) - mean_grouped_rewards = mx.mean(grouped_rewards, axis=1) - std_grouped_rewards = mx.std(grouped_rewards, axis=1) - - # Normalize rewards to compute advantages - mean_grouped_rewards = mx.repeat(mean_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1) - std_grouped_rewards = mx.repeat(std_grouped_rewards.reshape(-1, 1), group_size, axis=1).reshape(-1) - advantages = (rewards - mean_grouped_rewards) / (std_grouped_rewards + epsilon) - - # Create length mask for the shifted sequence - length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1) - - # Calculate policy gradient loss - per_token_loss = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) * advantages.reshape(-1, 1) - per_token_loss = -(per_token_loss - beta * kl_div) - - # Normalize loss properly per sequence - sequence_sums = (per_token_loss * length_mask).sum(axis=1) - sequence_lengths = length_mask.sum(axis=1) - loss = (sequence_sums / sequence_lengths).mean() - - # Calculate mean KL divergence - mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean() - - # Collect metrics for each reward function separately - reward_metrics = {} - for i, reward_func in enumerate(reward_funcs): - func_rewards = mx.array(reward_func( - prompts=prompt_text, + prompts=prompt_text, + completions=all_completion_texts, + answer=answer_text + )) + rewards += func_rewards + + if len(reward_funcs) > 1: + rewards /= len(reward_funcs) + + # Reshape rewards and compute advantages following GRPO formula + rewards_reshaped = rewards.reshape(batch_size, group_size) + mean_rewards = mx.broadcast_to(mx.mean(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1) + std_rewards = mx.broadcast_to(mx.std(rewards_reshaped, axis=1)[:, None], (rewards_reshaped.shape[0], group_size)).reshape(-1) + advantages = (rewards - mean_rewards) / (std_rewards + epsilon) + + # Compute KL divergence using Schulman's approximator + kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1 + + # Create mask for valid tokens + length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1) + + # Compute policy ratio + policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs)) + + # Compute per-token loss following GRPO formula + per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) + + # Average over tokens and sequences + sequence_sums = (per_token_loss * length_mask).sum(axis=1) + sequence_lengths = length_mask.sum(axis=1) + loss = (sequence_sums / sequence_lengths).mean() + + # Calculate mean KL divergence for metrics + mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean() + + # Collect reward metrics + reward_metrics = {} + for i, reward_func in enumerate(reward_funcs): + func_rewards = mx.array(reward_func( + prompts=prompt_text, completions=all_completion_texts, answer=answer_text )) - # func_grouped_rewards = func_rewards.reshape(batch_size, group_size) reward_metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards) reward_metrics[f'reward_func_{i}_std'] = mx.std(func_rewards) - + + # Clean up + del all_completions + mx.metal.clear_cache() + metrics = { 'total_rewards_mean': mx.mean(rewards), 'total_rewards_std': mx.std(rewards), - 'grouped_rewards_mean': mx.mean(grouped_rewards), - 
'grouped_rewards_std': mx.std(grouped_rewards), + 'grouped_rewards_mean': mx.mean(rewards_reshaped), + 'grouped_rewards_std': mx.std(rewards_reshaped), 'kl': mean_kl, **reward_metrics } - + return loss, sequence_lengths.sum(), metrics def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): - """ - Creates batches from dataset entries for GRPO training. - - Args: - dataset: List of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples - tokenizer: Tokenizer for processing inputs - batch_size: Size of each batch - max_seq_length: Maximum sequence length - train: Whether this is for training - - Yields: - Tuple containing: - - prompts_tokens: List of token sequences for current batch - - answers_tokens: List of token sequences - - prompts_text: List of prompt strings - - answers_text: List of answer strings - """ - # Verify dataset format + """Memory-optimized version of iterate_grpo_batches""" if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4: raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples") - # Sort by combined length of prompt + answer tokens - idx = sorted(range(len(dataset)), - key=lambda i: len(dataset[i][0]) + len(dataset[i][1])) - + # Sort by length but use generator to avoid keeping full sorted list in memory + def length_key(i): + return len(dataset[i][0]) + len(dataset[i][1]) + + idx = sorted(range(len(dataset)), key=length_key) + if len(dataset) < batch_size: raise ValueError( f"Dataset must have at least batch_size={batch_size} " f"examples but only has {len(dataset)}." ) - # Handle distributed training step = mx.distributed.init().size() if batch_size % step != 0: raise ValueError("The batch size must be divisible by the number of workers") - # Create batch indices - batch_idx = [ - idx[i : i + batch_size : step] - for i in range(0, len(idx) - batch_size + 1, batch_size) - ] + # Use generator for batch indices + def batch_index_generator(): + for i in range(0, len(idx) - batch_size + 1, batch_size): + yield idx[i : i + batch_size : step] while True: - # Shuffle batch indices if training - indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx)) + indices = ( + np.random.permutation(list(batch_index_generator())) if train + else batch_index_generator() + ) - for i in indices: - # Get current batch - current_batch = [dataset[j] for j in batch_idx[i]] + for batch_idx in indices: + current_batch = [dataset[j] for j in batch_idx] - # Extract all components prompts_tokens = [item[0] for item in current_batch] answers_tokens = [item[1] for item in current_batch] prompts_text = [item[2] for item in current_batch] @@ -553,7 +502,8 @@ def train_grpo( beta=args.beta, group_size=args.group_size, epsilon=args.epsilon, - ref_model=ref_model + ref_model=ref_model, + max_tokens=args.max_seq_length, ) # All reduce the gradients if running in distributed mode @@ -649,8 +599,10 @@ def train_grpo( losses += loss n_tokens += toks steps += 1 + for k, v in metrics.items(): accumulated_metrics[k] += v + mx.eval(state, losses, n_tokens) if it % args.steps_per_report == 0 or it == args.iters:
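Illustrative sketch (not part of the patch) of the single-step sampling pattern the new generate_grpo loop is built around: score the prompt-so-far, temperature-scale the last position's logits, sample, and stop on EOS. The model, tokenizer, and prompt below are assumed to be an already loaded MLX causal LM and its tokenizer; note that mx.random.categorical draws from unnormalized logits, so this sketch passes the scaled logits to it directly.

import mlx.core as mx

def sample_next_token(model, current_prompt, temperature=1.0):
    # One decoding step: score the prompt-so-far and sample from the
    # distribution at the last position.
    logits = model(current_prompt[None, :])   # (1, seq_len, vocab_size)
    token_logits = logits[0, -1]              # (vocab_size,)
    if temperature > 0:
        # mx.random.categorical samples from unnormalized logits, so the
        # temperature-scaled logits can be used without an explicit softmax.
        next_token = mx.random.categorical(token_logits[None, :] / temperature)[0]
    else:
        next_token = mx.argmax(token_logits)
    mx.eval(next_token)
    return next_token

# Usage sketch (model and tokenizer assumed loaded elsewhere):
# prompt = mx.array(tokenizer.encode("What is 2 + 2?"))
# token = sample_next_token(model, prompt, temperature=0.8)
# finished = token.item() == tokenizer.eos_token_id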
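A self-contained sketch (again, not part of the patch) of the quantities the rewritten grpo_loss assembles: group-relative advantages, Schulman's k3 KL approximator, and the length-masked per-sequence average. It uses only mlx.core ops that already appear in the diff; the rewards, log-probabilities, and lengths are made-up stand-ins for illustration.

import mlx.core as mx

batch_size, group_size, seq_len = 2, 4, 6
beta, epsilon = 0.1, 1e-4

# One scalar reward per sampled completion (batch_size * group_size of them).
rewards = mx.array([1.0, 0.0, 0.5, 1.0, 0.0, 0.0, 1.0, 0.5])

# Group-relative advantages: each completion is scored against the mean/std
# of the completions sampled for the same prompt.
rewards_reshaped = rewards.reshape(batch_size, group_size)
mean_rewards = mx.broadcast_to(
    mx.mean(rewards_reshaped, axis=1)[:, None], rewards_reshaped.shape
).reshape(-1)
std_rewards = mx.broadcast_to(
    mx.std(rewards_reshaped, axis=1)[:, None], rewards_reshaped.shape
).reshape(-1)
advantages = (rewards - mean_rewards) / (std_rewards + epsilon)

# Stand-in per-token log-probs for the current and reference policies
# (grpo_loss gathers these from the model logits with take_along_axis).
token_log_probs = mx.full((batch_size * group_size, seq_len), -2.0)
ref_token_log_probs = token_log_probs + 0.1

# Schulman's k3 KL approximator: exp(d) - d - 1 with d = log pi_ref - log pi.
# It is non-negative and estimates KL(pi || pi_ref) from on-policy samples.
diff = ref_token_log_probs - token_log_probs
kl_div = mx.exp(diff) - diff - 1

# Policy ratio exp(logp - stop_gradient(logp)) is 1 in the forward pass but
# still routes the policy gradient through logp.
policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
per_token_loss = -(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div)

# Mask out padding, average per sequence, then over the batch.
lengths = mx.array([6, 4, 6, 3, 5, 6, 2, 4])
length_mask = mx.arange(seq_len)[None, :] < lengths[:, None]
sequence_sums = (per_token_loss * length_mask).sum(axis=1)
loss = (sequence_sums / length_mask.sum(axis=1)).mean()
print(loss.item())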
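A toy illustration of the strided batch-index slice in iterate_grpo_batches: each window of batch_size length-sorted indices is thinned by step (the distributed world size), so every worker draws batch_size // step examples per batch. The numbers below are made up.

idx = list(range(8))             # dataset indices, already sorted by length
batch_size, step = 4, 2          # step = mx.distributed.init().size() in the patch
batches = [
    idx[i : i + batch_size : step]
    for i in range(0, len(idx) - batch_size + 1, batch_size)
]
print(batches)                   # [[0, 2], [4, 6]]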