diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index ec080bca..d65cf4b6 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -57,228 +57,6 @@ class GRPOTrainingArgs(TrainingArgs):
     )
 
 
-def generate_step(
-    prompt: mx.array,
-    model: nn.Module,
-    *,
-    max_tokens: int = 256,
-    sampler: Optional[Callable[mx.array, mx.array]] = None,
-    logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None,
-    max_kv_size: Optional[int] = None,
-    prompt_cache: Optional[Any] = None,
-    prefill_step_size: int = 512,
-    prompt_progress_callback: Optional[Callable[int, int]] = None,
-) -> Generator[Tuple[mx.array, mx.array], None, None]:
-    """
-    A generator producing token ids based on the given prompt from the model.
-
-    Args:
-        prompt (mx.array): The input prompt.
-        model (nn.Module): The model to use for generation.
-        max_tokens (int): The maximum number of tokens. Use``-1`` for an infinite
-          generator. Default: ``256``.
-        sampler (Callable[mx.array, mx.array], optional): A sampler for sampling a
-          token from a vector of log probabilities. Default: ``None``.
-        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
-          A list of functions that take tokens and logits and return the processed
-          logits. Default: ``None``.
-        max_kv_size (int, optional): Maximum size of the key-value cache. Old
-          entries (except the first 4 tokens) will be overwritten.
-        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
-          provided, the cache will be updated in place.
-        prefill_step_size (int): Step size for processing the prompt.
-        kv_bits (int, optional): Number of bits to use for KV cache quantization.
-          None implies no cache quantization. Default: ``None``.
-        kv_group_size (int): Group size for KV cache quantization. Default: ``64``.
-        quantized_kv_start (int): Step to begin using a quantized KV cache.
-          when ``kv_bits`` is non-None. Default: ``0``.
-        prompt_prorgress_callback (Callable[int, int]): A call-back which takes the
-          prompt tokens processed so far and the total number of prompt tokens.
-
-    Yields:
-        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
- """ - - y = prompt - tokens = None - - # Create the KV cache for generation - if prompt_cache is None: - prompt_cache = cache.make_prompt_cache( - model, - max_kv_size=max_kv_size, - ) - elif len(prompt_cache) != len(model.layers): - raise ValueError("Wrong number of layers in the prompt cache.") - - prompt_progress_callback = prompt_progress_callback or (lambda *_: None) - - sampler = sampler or (lambda x: mx.argmax(x, axis=-1)) - - def _step(y): - with mx.stream(generation_stream): - logits = model(y[None], cache=prompt_cache) - logits = logits[:, -1, :] - - if logits_processors: - nonlocal tokens - tokens = mx.concat([tokens, y]) if tokens is not None else y - - for processor in logits_processors: - logits = processor(tokens, logits) - - logprobs = logits - mx.logsumexp(logits, keepdims=True) - y = sampler(logprobs) - return mx.stop_gradient(y), mx.stop_gradient(logprobs.squeeze(0)) - - with mx.stream(generation_stream): - total_prompt_tokens = y.size - prompt_processed_tokens = 0 - while y.size > prefill_step_size: - model(y[:prefill_step_size][None], cache=prompt_cache) - mx.eval([c.state for c in prompt_cache]) - prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens) - prompt_processed_tokens += prefill_step_size - y = y[prefill_step_size:] - mx.metal.clear_cache() - - y, logprobs = _step(y) - - mx.eval(y, logprobs) - n = 0 - while True: - if n != max_tokens: - next_y, next_logprobs = _step(y) - mx.eval(next_y, next_logprobs) - if n == 0: - mx.eval(y) - prompt_progress_callback(total_prompt_tokens, total_prompt_tokens) - if n == max_tokens: - break - yield y.item(), logprobs - if n % 256 == 0: - mx.metal.clear_cache() - y, logprobs = next_y, next_logprobs - n += 1 - - -def generate_grpo( - model: nn.Module, - prompts, - max_tokens, - tokenizer, - group_size, - end_token: str = "", - temperature: float = 0.8, - batch_size: int = 1, -): - try: - import time - - start_time = time.time() - - if len(prompts.shape) == 1: - prompts = prompts[None, :] - if prompts.shape[1] == 0: - return None - - total_samples = prompts.shape[0] * group_size - expanded_prompts = mx.repeat(prompts, group_size, axis=0) - end_sequence = mx.array(tokenizer.encode(end_token)) - results = [] - mx.eval(expanded_prompts, results) - - print(f"Setup time: {time.time() - start_time:.2f}s") - print(f"Generating {total_samples} samples with max_tokens={max_tokens}") - - total_tokens_generated = 0 - generation_start_time = time.time() - - # Process in batches - for batch_start in range(0, total_samples, batch_size): - batch_end = min(batch_start + batch_size, total_samples) - batch_time = time.time() - print( - f"Starting batch {batch_start//batch_size + 1}/{(total_samples + batch_size - 1)//batch_size}: samples {batch_start}-{batch_end-1}" - ) - - # Custom sampler function that handles temperature - def temp_sampler(logits): - return mx.random.categorical(logits / temperature) - - # Batched processing - for idx in range(batch_start, batch_end): - sample_start_time = time.time() - current_tokens = [] - prompt_cache = cache.make_prompt_cache(model) - - # The generate_step function yields one token at a time - # We'll collect tokens until we hit max_tokens or a stopping condition - for i, (token, _) in enumerate( - generate_step( - expanded_prompts[idx], - model, - max_tokens=max_tokens, # This is the maximum number of steps - sampler=temp_sampler, - prompt_cache=prompt_cache, - ) - ): - # Check for EOS token - if token == tokenizer.eos_token_id: - break - - current_tokens.append(token) - - print(token) - 
+    def _step(y):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=prompt_cache)
+            logits = logits[:, -1, :]
+
+            if logits_processors:
+                nonlocal tokens
+                tokens = mx.concat([tokens, y]) if tokens is not None else y
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            next_token = sampler(logprobs)
+            return mx.stop_gradient(next_token), mx.stop_gradient(logprobs.squeeze(0))
+
+    try:
+        with mx.stream(generation_stream):
+            y, logprobs = _step(y)
+            mx.eval(y, logprobs)
+
+        for n in range(max_tokens):
+            yield y.item(), logprobs
+            next_y, next_logprobs = _step(y)
+            mx.eval(next_y, next_logprobs)
+            y, logprobs = next_y, next_logprobs
+            if (n + 1) % 32 == 0:
+                mx.metal.clear_cache()
+    finally:
+        mx.metal.clear_cache()
+
+
+def generate_grpo(
     model: nn.Module,
     tokenizer,
     prompt_tokens,
     max_tokens: int,
     group_size: int,
+    end_token: str = "</answer>",
     temperature: float = 0.8,
-    batch_size: int = 1
+    batch_size: int = 1,
 ):
-    """Generate completions without tracking gradients"""
-
-    # Store original state
-    was_training = model.training
-
-    # Force eval mode
-    model.eval()
-
-    # Prepare prompts
-    total_samples = len(prompt_tokens)
-    all_completions = []
-    all_completion_texts = []
-    batch_indices = []
-
-    # Process in smaller batches
-    for i in range(0, total_samples, batch_size):
-        current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i : i + current_batch_size]
-
-        # Pad sequences to the same length
-        max_prompt_len = max(len(p) for p in batch_prompts)
-        padded_prompts = []
-
-        for prompt in batch_prompts:
-            padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
-            padded_prompts.append(prompt + padding)
-
-        # Convert to tensor and explicitly stop gradient
-        prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
-
-        try:
-            completions = generate_grpo(
-                model,
-                prompt_tensor,
-                max_tokens,
-                tokenizer,
-                group_size,
-                temperature=temperature,
-                batch_size=current_batch_size,
-            )
-
-            if completions is not None:
-                for j, completion_ids in enumerate(completions):
-                    prompt_idx = i + (j // group_size)
+    try:
+        end_sequence = mx.array(tokenizer.encode(end_token))
+        total_samples = len(prompt_tokens)
+        all_completions = []
+        all_completion_texts = []
+        batch_indices = []
+
+        def temp_sampler(logits):
+            return mx.random.categorical(logits / temperature)
+
+        for i in range(0, total_samples, batch_size):
+            current_batch_size = min(batch_size, total_samples - i)
+            batch_prompts = prompt_tokens[i : i + current_batch_size]
+
+            max_prompt_len = max(len(p) for p in batch_prompts)
+            padded_prompts = []
+            for prompt in batch_prompts:
+                padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
+                padded_prompts.append(prompt + padding)
+
+            prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
+
+            if len(prompt_tensor.shape) == 1:
+                prompt_tensor = prompt_tensor[None, :]
+            if prompt_tensor.shape[1] == 0:
+                continue
+
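+            # Repeat each prompt group_size times so that every prompt gets
+            # group_size independent rollouts, forming its GRPO group.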
+            expanded_prompts = mx.repeat(prompt_tensor, group_size, axis=0)
+            batch_results = []
+
+            total_prompt_samples = expanded_prompts.shape[0]
+            for prompt_idx in range(total_prompt_samples):
+                current_tokens = []
+                prompt_cache = cache.make_prompt_cache(model)
+
+                for token, _ in generate_step(
+                    expanded_prompts[prompt_idx],
+                    model,
+                    max_tokens=max_tokens,
+                    sampler=temp_sampler,
+                    prompt_cache=prompt_cache,
+                ):
+                    if token == tokenizer.eos_token_id:
+                        break
+
+                    current_tokens.append(token)
+                    if len(current_tokens) >= len(end_sequence) and mx.array_equal(
+                        mx.array(current_tokens[-len(end_sequence):]), end_sequence
+                    ):
+                        break
+
+                if current_tokens:
+                    batch_results.append(mx.array(current_tokens))
+
+            if batch_results:
+                for j, completion_ids in enumerate(batch_results):
+                    prompt_idx = i + (j // group_size)
                     if prompt_idx < total_samples:
                         batch_indices.append(prompt_idx)
                         completion_text = tokenizer.decode(completion_ids.tolist())
-                        all_completions.append(completion_ids)
+                        all_completions.append(mx.stop_gradient(completion_ids))
                         all_completion_texts.append(completion_text)
-                        mx.eval(completion_ids)
-        except Exception as e:
-            print(f"Generation error: {e}")
-            continue
+
+            mx.metal.clear_cache()
 
-    # Restore original state
-    if was_training:
-        model.train()
-
-    mx.metal.clear_cache()
+    finally:
+        mx.metal.clear_cache()
 
     return all_completions, all_completion_texts, batch_indices
@@ -375,6 +202,9 @@ def grpo_loss(
     ref_model,
     tokenizer,
     batch,
+    completions=None,
+    completion_texts=None,
+    batch_indices=None,
     reward_funcs: Optional[List[RewardFunctions]] = None,
     beta: float = 0.1,
     group_size: int = 4,
@@ -386,36 +216,36 @@
     is_validation: bool = False
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
+
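+    # Reuse pre-generated completions when the caller provides them
+    # (train_grpo generates before computing gradients); otherwise generate
+    # here, as evaluate_grpo does during validation.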
+    if completions is not None and completion_texts is not None and batch_indices is not None:
+        all_completions = completions
+        all_completion_texts = completion_texts
+        batch_indices = batch_indices
+    else:
+        all_completions, all_completion_texts, batch_indices = generate_grpo(
+            model=model,
+            tokenizer=tokenizer,
+            prompt_tokens=prompt_tokens,
+            max_tokens=max_tokens,
+            group_size=group_size,
+            temperature=temperature,
+            batch_size=batch_size
+        )
 
-    # Generate completions without tracking gradients
-    all_completions, all_completion_texts, batch_indices = generate_without_gradients(
-        model=model,
-        tokenizer=tokenizer,
-        prompt_tokens=prompt_tokens,
-        max_tokens=max_tokens,
-        group_size=group_size,
-        temperature=temperature,
-        batch_size=batch_size
-    )
-
-    # If we didn't generate any completions, return early
     if not all_completions:
         raise ValueError(
             "No completions were generated. Please check your model and inputs."
         )
 
-    # Create expanded prompts and answers based on actual generated completions
     expanded_answers = []
     expanded_prompts = []
-
-    # Group completions by their original prompt
     unique_prompt_indices = sorted(set(batch_indices))
     grouped_completions = {idx: [] for idx in unique_prompt_indices}
 
     for i, completion_idx in enumerate(batch_indices):
         grouped_completions[completion_idx].append(i)
 
-    # Rebuild completions in the correct order
     ordered_completions = []
     ordered_completion_texts = []
     ordered_batch_indices = []
@@ -426,8 +256,6 @@
             ordered_completions.append(all_completions[idx])
             ordered_completion_texts.append(all_completion_texts[idx])
             ordered_batch_indices.append(prompt_idx)
-
-            # Add corresponding prompt and answer
             expanded_prompts.append(prompt_text[prompt_idx])
             expanded_answers.append(answer_text[prompt_idx])
 
@@ -435,14 +263,11 @@
     all_completion_texts = ordered_completion_texts
     batch_indices = ordered_batch_indices
 
-    # Create new input tensors for the model to compute logits with gradient tracking
     max_length = max(ids.shape[0] for ids in all_completions)
    padded_completions = []
     attention_masks = []
 
     for completion_ids in all_completions:
-        # Convert the pre-generated completion to a regular tensor (not stop_gradient)
-        # This allows gradients to flow during the loss computation phase
         completion_tensor = mx.array(completion_ids.tolist())
         padding_length = max_length - completion_tensor.shape[0]
 
@@ -458,12 +283,10 @@
         padded_completions.append(padded_ids)
         attention_masks.append(mask)
 
-    # Rest of the function remains the same
     inputs = mx.stack(padded_completions)
     attention_mask = mx.stack(attention_masks)
     lengths = attention_mask.sum(axis=1)
 
-    # Current policy probabilities
     token_log_probs = get_per_token_logps(model, inputs, lengths)
     mx.eval(token_log_probs)
 
@@ -487,10 +310,8 @@
     token_log_probs = mx.stack(padded_log_probs)
     ref_token_log_probs = mx.stack(padded_ref_log_probs)
 
-    # Create array to store rewards from each function
     all_func_rewards = []
 
-    # Collect rewards from each function separately
     for reward_func in reward_funcs:
         func_rewards = mx.array(
             reward_func(
@@ -501,10 +322,8 @@
         )
         all_func_rewards.append(func_rewards)
 
-    # Stack rewards to shape (num_samples, num_funcs)
     rewards = mx.stack(all_func_rewards, axis=1)
 
-    # Apply weights and sum
     if reward_weights is not None:
         if len(reward_weights) != len(reward_funcs):
             raise ValueError(
@@ -517,24 +336,19 @@
 
     rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
 
-    # Get number of unique prompts
     num_unique_prompts = len(unique_prompt_indices)
 
-    # Reshape rewards based on actual groups
     rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
     for i, prompt_idx in enumerate(batch_indices):
         prompt_position = unique_prompt_indices.index(prompt_idx)
         rewards_by_prompt[prompt_position].append(rewards[i])
 
-    # Calculate advantages within each group
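+    # Group-relative advantages: normalize each reward by the mean and std of
+    # its own prompt group, so every completion is scored against the other
+    # rollouts of the same prompt.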
     advantages = mx.zeros_like(rewards)
     for i, prompt_rewards in enumerate(rewards_by_prompt):
-        if len(prompt_rewards) > 1:  # Only normalize if we have multiple samples
+        if len(prompt_rewards) > 1:
             prompt_rewards = mx.array(prompt_rewards)
             mean_reward = mx.mean(prompt_rewards)
             std_reward = mx.std(prompt_rewards)
-
-            # Find indices for this prompt
             indices = [
                 j
                 for j, idx in enumerate(batch_indices)
@@ -545,7 +359,6 @@
                 std_reward + epsilon
             )
         else:
-            # If only one sample, advantage is 0
             idx = batch_indices.index(unique_prompt_indices[i])
             advantages[idx] = 0.0
 
@@ -746,6 +559,7 @@ def evaluate_grpo(
             ref_model=ref_model,
             temperature=temperature,
             max_tokens=max_tokens,
+            is_validation=True
         )
 
         all_losses += losses * toks
@@ -803,21 +617,37 @@ def train_grpo(
     state = [model.state, optimizer.state]
 
     def step(batch):
+        # Extract prompt tokens from the batch
+        prompt_tokens, targets, prompt_lens, target_lens = batch
+
+        # First, generate completions without gradient tracking: generation
+        # runs outside loss_value_and_grad, so no gradient graph is built.
+        all_completions, all_completion_texts, batch_indices = generate_grpo(
+            model=model,
+            tokenizer=tokenizer,
+            prompt_tokens=prompt_tokens,
+            max_tokens=args.max_completion_length,
+            group_size=args.group_size,
+            temperature=args.temperature
+        )
+
+        # Then compute the loss and gradients, passing the pre-generated
+        # completions and their prompt indices through to grpo_loss.
         (loss, toks, metrics), grad = loss_value_and_grad(
             model,
             tokenizer=tokenizer,
-            batch=batch,
+            batch=(prompt_tokens, targets, prompt_lens, target_lens),
+            completions=all_completions,
+            completion_texts=all_completion_texts,
+            batch_indices=batch_indices,
             reward_funcs=reward_funcs,
             beta=args.beta,
             group_size=args.group_size,
             epsilon=args.epsilon,
             ref_model=ref_model,
-            max_tokens=args.max_completion_length,
-            temperature=args.temperature,
         )
 
         grad = average_gradients(grad)
-
         optimizer.update(model, grad)
 
         return loss, toks, metrics