From 132225a0181bb356459c2890f03e0bfd23966ac0 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Sat, 1 Mar 2025 22:23:33 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_trainer.py | 126 ++++++++++++++----------
 1 file changed, 60 insertions(+), 66 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index 70ac2eda..3083e785 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -61,112 +61,101 @@ def generate_grpo(
     temperature: float = 0.8,
     batch_size: int = 1
 ):
-    if len(prompts.shape) == 1:
-        prompts = prompts[None, :]
-    if prompts.shape[1] == 0:
-        return None
+    # Store original training state
+    was_training = model.training
 
-    total_samples = prompts.shape[0] * group_size
-    expanded_prompts = mx.repeat(prompts, group_size, axis=0)
-    end_sequence = mx.array(tokenizer.encode(end_token))
-    results = []
-    mx.eval(expanded_prompts)
+    # Set model to eval mode for generation
+    model.eval()
 
     try:
+        if len(prompts.shape) == 1:
+            prompts = prompts[None, :]
+        if prompts.shape[1] == 0:
+            return None
+
+        total_samples = prompts.shape[0] * group_size
+        expanded_prompts = mx.repeat(prompts, group_size, axis=0)
+        end_sequence = mx.array(tokenizer.encode(end_token))
+        results = []
+        mx.eval(expanded_prompts)
+
         # Process in batches
         for batch_start in range(0, total_samples, batch_size):
             batch_end = min(batch_start + batch_size, total_samples)
 
             if is_training:
-                # Training mode with batched processing
+                # Training-specific generation logic
                 batch_inputs = expanded_prompts[batch_start:batch_end]
+                batch_tokens = [[] for _ in range(batch_end - batch_start)]
                 prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)]
 
-                # Initial forward pass for all prompts in batch
-                batch_logits = []
+                # Initial forward pass
                 for i, prompt in enumerate(batch_inputs):
                     logits = model(prompt[None], cache=prompt_caches[i])[:, -1]
-                    batch_logits.append(logits)
-                mx.eval(batch_logits, prompt_caches)
-
-                # Track tokens for each sequence in the batch
-                batch_tokens = [[] for _ in range(batch_end - batch_start)]
-
-                # Initial token generation for all sequences in batch
-                for i in range(len(batch_logits)):
-                    logits_temp = batch_logits[i] / temperature
+                    logits_temp = logits / temperature
                     next_token = mx.random.categorical(logits_temp)
                     token = next_token.item()
-                    mx.eval(logits_temp, next_token, token)
                     batch_tokens[i].append(token)
+                    del logits, logits_temp, next_token
+
+                mx.eval([tokens[-1] for tokens in batch_tokens])
+                mx.metal.clear_cache()
+
+                active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id]
+
+                # Generate remaining tokens
+                for _ in range(max_tokens - 1):
+                    if not active_indices:
+                        break
 
-                    # Check if this token already completes the sequence
-                    if token == tokenizer.eos_token_id:
-                        continue
-                    else:
-                        # Set up for next token
-                        current_input = mx.array([token])
-                        batch_logits[i] = model(current_input[None], cache=prompt_caches[i])[:, -1]
-
-                mx.eval(batch_logits)
-                active_indices = [i for i, tokens in enumerate(batch_tokens) if tokens[-1] != tokenizer.eos_token_id and len(tokens) < max_tokens]
-
-                # Generate tokens until all sequences are complete
-                while active_indices and max(len(tokens) for tokens in batch_tokens) < max_tokens:
                     next_active = []
                     for idx in active_indices:
-                        logits_temp = batch_logits[idx] / temperature
+                        current_input = mx.array([batch_tokens[idx][-1]])
+                        logits = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+                        logits_temp = logits / temperature
                         next_token = mx.random.categorical(logits_temp)
                         token = next_token.item()
-                        mx.eval(logits_temp, next_token, token)
                         batch_tokens[idx].append(token)
 
-                        # Check for end sequence
+                        # Check for end conditions
+                        is_end = False
                         if len(batch_tokens[idx]) >= len(end_sequence):
                             test_sequence = batch_tokens[idx][-len(end_sequence):]
-                            is_end = mx.array_equal(
-                                mx.array(test_sequence),
-                                end_sequence
-                            )
-                        else:
-                            is_end = False
+                            is_end = mx.array_equal(mx.array(test_sequence), end_sequence)
 
-                        if is_end or token == tokenizer.eos_token_id or len(batch_tokens[idx]) >= max_tokens:
-                            # This sequence is done
-                            pass
-                        else:
-                            # Continue with this sequence
+                        if not (is_end or token == tokenizer.eos_token_id):
                             next_active.append(idx)
-                            current_input = mx.array([token])
-                            batch_logits[idx] = model(current_input[None], cache=prompt_caches[idx])[:, -1]
+
+                        del logits, logits_temp, next_token, current_input
 
-                    mx.eval([batch_logits[idx] for idx in next_active])
+                    mx.eval([tokens[-1] for tokens in batch_tokens])
+                    mx.metal.clear_cache()
                     active_indices = next_active
-
-                # Clear caches after processing this batch
+
+                # Clean up caches
                 for pc in prompt_caches:
                     del pc
 
-                # Add batch results to overall results
+                # Process results
                 for tokens in batch_tokens:
                     if tokens:
-                        # Filter out any special tokens that might appear after the end token
-                        if len(tokens) >= len(end_sequence):
-                            for i in range(len(tokens) - len(end_sequence) + 1):
-                                if mx.array_equal(
-                                    mx.array(tokens[i:i+len(end_sequence)]),
-                                    end_sequence
-                                ):
-                                    tokens = tokens[:i+len(end_sequence)]
-                                    break
+                        # Truncate at end token if present
+                        for i in range(len(tokens) - len(end_sequence) + 1):
+                            if mx.array_equal(
+                                mx.array(tokens[i:i+len(end_sequence)]),
+                                end_sequence
+                            ):
+                                tokens = tokens[:i+len(end_sequence)]
+                                break
 
-                        # Filter out EOS token if it's the last token
                         if tokens and tokens[-1] == tokenizer.eos_token_id:
                             tokens = tokens[:-1]
 
-                        # Only add non-empty token lists
                         if tokens:
                             results.append(mx.array(tokens))
+
+                del batch_inputs, batch_tokens, prompt_caches
+                mx.metal.clear_cache()
             else:
                 # Non-training mode with batched processing
                 for idx in range(batch_start, batch_end):
@@ -196,12 +185,17 @@ def generate_grpo(
                     results.append(mx.array(current_tokens))
                     mx.metal.clear_cache()
 
+        mx.eval(results)
         return results
 
     except Exception as e:
         print(f"Generation error: {str(e)}")
        return None
+
+    finally:
+        # Don't restore training mode - let the caller handle it
+        pass
 
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
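
With this change, generate_grpo() stores was_training, switches the model to eval mode on entry, and the new finally block intentionally leaves it in eval mode, so restoring the previous mode is the caller's responsibility. Below is a minimal caller-side sketch of that contract; it is not part of the patch, and the exact generate_grpo signature beyond the parameters visible in the diff (temperature, batch_size, and the names referenced in the body such as group_size, end_token, max_tokens, is_training) is an assumption.

    # Hypothetical caller-side wrapper; argument order and surrounding trainer
    # loop are assumed, not taken from the patch.
    was_training = model.training          # mlx.nn.Module tracks train/eval state
    try:
        completions = generate_grpo(
            model,
            tokenizer,
            prompts,                       # prompt token ids, shape (num_prompts, seq_len)
            max_tokens=max_tokens,
            group_size=group_size,
            end_token=end_token,
            is_training=True,
            temperature=0.8,
            batch_size=1,
        )
    finally:
        if was_training:
            model.train()                  # restore training mode before the GRPO update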