From 3dfb21267bdc9d5dff7ec3a3fd5718dc0df2fe05 Mon Sep 17 00:00:00 2001
From: Goekdeniz-Guelmez
Date: Wed, 5 Mar 2025 12:59:41 +0100
Subject: [PATCH] updates

---
 llms/mlx_lm/tuner/grpo_reward_functions.py |  54 +-
 llms/mlx_lm/tuner/grpo_trainer.py          | 675 ++++++++++++---------
 2 files changed, 432 insertions(+), 297 deletions(-)

diff --git a/llms/mlx_lm/tuner/grpo_reward_functions.py b/llms/mlx_lm/tuner/grpo_reward_functions.py
index 3b5c56b5..5f59da65 100644
--- a/llms/mlx_lm/tuner/grpo_reward_functions.py
+++ b/llms/mlx_lm/tuner/grpo_reward_functions.py
@@ -1,6 +1,5 @@
-from typing import List, Optional, Callable
 import re
-
+from typing import Callable, List, Optional
 
 RewardFunctions = Callable[[List[str], List[str], List[str]], List[float]]
 
@@ -14,52 +13,73 @@ def r1_extract_xml_answer(text: str) -> str:
         print("r1_extract_xml_answer returned empty string")
         return ""
 
-def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+
+def r1_int_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
     return [0.5 if r and r.isdigit() else 0.0 for r in extracted_responses]
 
-def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+
+def r1_accuracy_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions or not answer:
         return [0.0] * len(prompts)
     extracted_responses = [r1_extract_xml_answer(r) for r in completions]
-    return [2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)]
+    return [
+        2.0 if r and a and r == a else 0.0 for r, a in zip(extracted_responses, answer)
+    ]
 
-def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+
+def r1_soft_format_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
-    
+
     scores = []
     for completion in completions:
         if not completion:
             scores.append(0.0)
             continue
-        
+
         reason_start = completion.find("<reasoning>")
         reason_end = completion.find("</reasoning>")
         answer_start = completion.find("<answer>")
         answer_end = completion.find("</answer>")
-        
-        if (reason_start != -1 and reason_end != -1 and
-            answer_start != -1 and answer_end != -1 and
-            reason_start < reason_end < answer_start < answer_end):
-            reason_content = completion[reason_start+13:reason_end].strip()
-            answer_content = completion[answer_start+8:answer_end].strip()
+
+        if (
+            reason_start != -1
+            and reason_end != -1
+            and answer_start != -1
+            and answer_end != -1
+            and reason_start < reason_end < answer_start < answer_end
+        ):
+            reason_content = completion[reason_start + 13 : reason_end].strip()
+            answer_content = completion[answer_start + 8 : answer_end].strip()
             if reason_content and answer_content:
                 scores.append(0.5)
                 continue
         scores.append(0.0)
     return scores
 
-def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
+
+def r1_strict_format_reward_func(
+    prompts: list, completions: list, answer: list, **kwargs
+) -> list[float]:
     if not completions:
         return [0.0] * len(prompts)
     pattern = r"<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n
" matches = [bool(re.search(pattern, r)) if r else False for r in completions] return [0.5 if match else 0.0 for match in matches] -def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]: + +def r1_count_xml( + prompts: list, completions: list, answer: list, **kwargs +) -> list[float]: if not completions: return [0.0] * len(prompts) scores = [] @@ -79,4 +99,4 @@ def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> li end_text = text.split("")[-1] count -= len(end_text) * 0.001 if len(end_text) > 0 else 0 scores.append(max(0.0, count)) - return scores \ No newline at end of file + return scores diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py index 3083e785..a9584bda 100644 --- a/llms/mlx_lm/tuner/grpo_trainer.py +++ b/llms/mlx_lm/tuner/grpo_trainer.py @@ -1,19 +1,28 @@ # Copyright Ā© 2024 Apple Inc. -from typing import List, Optional, Tuple, Generator, Callable, Any +import time from dataclasses import dataclass, field from pathlib import Path -import time +from typing import Any, Callable, Generator, List, Optional, Tuple -from mlx.utils import tree_flatten import mlx.core as mx import mlx.nn as nn import numpy as np +from mlx.utils import tree_flatten -from .grpo_reward_functions import r1_accuracy_reward_func, r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, r1_count_xml,r1_extract_xml_answer, RewardFunctions -from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients -from ..utils import generate_step from ..models import cache +from ..utils import generation_stream +from .grpo_reward_functions import ( + RewardFunctions, + r1_accuracy_reward_func, + r1_count_xml, + r1_extract_xml_answer, + r1_int_reward_func, + r1_soft_format_reward_func, + r1_strict_format_reward_func, +) +from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint + @dataclass class GRPOTrainingArgs(TrainingArgs): @@ -21,9 +30,7 @@ class GRPOTrainingArgs(TrainingArgs): default=4, metadata={"help": "Number of responses per prompt."}, ) - beta: float = field( - default=0.1, metadata={"help": "KL penalty coefficient."} - ) + beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."}) epsilon: float = field( default=1e-4, metadata={"help": "The Epsilon for numerical stability."} ) @@ -34,168 +41,237 @@ class GRPOTrainingArgs(TrainingArgs): default=None, metadata={ "help": "Path to reference model weights. If None, uses the same model." - } + }, ) temperature: float = field( default=1.0, metadata={ "help": "Temperature for sampling. The higher the temperature, the more random the completions." - } + }, ) reward_weights: Optional[List[float]] = field( default=None, metadata={ "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`." - } + }, ) +def generate_step( + prompt: mx.array, + model: nn.Module, + *, + max_tokens: int = 256, + sampler: Optional[Callable[mx.array, mx.array]] = None, + logits_processors: Optional[List[Callable[[mx.array, mx.array], mx.array]]] = None, + max_kv_size: Optional[int] = None, + prompt_cache: Optional[Any] = None, + prefill_step_size: int = 512, + prompt_progress_callback: Optional[Callable[int, int]] = None, +) -> Generator[Tuple[mx.array, mx.array], None, None]: + """ + A generator producing token ids based on the given prompt from the model. + + Args: + prompt (mx.array): The input prompt. 
+        model (nn.Module): The model to use for generation.
+        max_tokens (int): The maximum number of tokens. Use ``-1`` for an infinite
+          generator. Default: ``256``.
+        sampler (Callable[[mx.array], mx.array], optional): A sampler for sampling a
+          token from a vector of log probabilities. Default: ``None``.
+        logits_processors (List[Callable[[mx.array, mx.array], mx.array]], optional):
+          A list of functions that take tokens and logits and return the processed
+          logits. Default: ``None``.
+        max_kv_size (int, optional): Maximum size of the key-value cache. Old
+          entries (except the first 4 tokens) will be overwritten.
+        prompt_cache (List[Any], optional): A pre-computed prompt cache. Note, if
+          provided, the cache will be updated in place.
+        prefill_step_size (int): Step size for processing the prompt.
+        prompt_progress_callback (Callable[[int, int], None], optional): A callback
+          which takes the prompt tokens processed so far and the total number of
+          prompt tokens.
+
+    Yields:
+        Tuple[mx.array, mx.array]: One token and a vector of log probabilities.
+    """
+
+    y = prompt
+    tokens = None
+
+    # Create the KV cache for generation
+    if prompt_cache is None:
+        prompt_cache = cache.make_prompt_cache(
+            model,
+            max_kv_size=max_kv_size,
+        )
+    elif len(prompt_cache) != len(model.layers):
+        raise ValueError("Wrong number of layers in the prompt cache.")
+
+    prompt_progress_callback = prompt_progress_callback or (lambda *_: None)
+
+    sampler = sampler or (lambda x: mx.argmax(x, axis=-1))
+
+    def _step(y):
+        with mx.stream(generation_stream):
+            logits = model(y[None], cache=prompt_cache)
+            logits = logits[:, -1, :]
+
+            if logits_processors:
+                nonlocal tokens
+                tokens = mx.concat([tokens, y]) if tokens is not None else y
+
+                for processor in logits_processors:
+                    logits = processor(tokens, logits)
+
+            logprobs = logits - mx.logsumexp(logits, keepdims=True)
+            y = sampler(logprobs)
+            return y, logprobs.squeeze(0)
+
+    with mx.stream(generation_stream):
+        total_prompt_tokens = y.size
+        prompt_processed_tokens = 0
+        while y.size > prefill_step_size:
+            model(y[:prefill_step_size][None], cache=prompt_cache)
+            mx.eval([c.state for c in prompt_cache])
+            prompt_progress_callback(prompt_processed_tokens, total_prompt_tokens)
+            prompt_processed_tokens += prefill_step_size
+            y = y[prefill_step_size:]
+            mx.metal.clear_cache()
+
+        y, logprobs = _step(y)
+
+    mx.eval(y, logprobs)
+    n = 0
+    while True:
+        if n != max_tokens:
+            next_y, next_logprobs = _step(y)
+            mx.eval(next_y, next_logprobs)
+        if n == 0:
+            mx.eval(y)
+            prompt_progress_callback(total_prompt_tokens, total_prompt_tokens)
+        if n == max_tokens:
+            break
+        yield y.item(), logprobs
+        if n % 256 == 0:
+            mx.metal.clear_cache()
+        y, logprobs = next_y, next_logprobs
+        n += 1
+
+
 def generate_grpo(
     model: nn.Module,
     prompts,
     max_tokens,
     tokenizer,
     group_size,
-    is_training=False,
     end_token: str = "</answer>",
     temperature: float = 0.8,
-    batch_size: int = 1
+    batch_size: int = 1,
 ):
-    # Store original training state
-    was_training = model.training
-    
-    # Set model to eval mode for generation
-    model.eval()
-    
     try:
+        start_time = time.time()
+
         if len(prompts.shape) == 1:
             prompts = prompts[None, :]
         if prompts.shape[1] == 0:
             return None
-        
+
         total_samples = 
prompts.shape[0] * group_size expanded_prompts = mx.repeat(prompts, group_size, axis=0) end_sequence = mx.array(tokenizer.encode(end_token)) results = [] mx.eval(expanded_prompts) - + + print(f"Setup time: {time.time() - start_time:.2f}s") + print(f"Generating {total_samples} samples with max_tokens={max_tokens}") + + total_tokens_generated = 0 + generation_start_time = time.time() + # Process in batches for batch_start in range(0, total_samples, batch_size): batch_end = min(batch_start + batch_size, total_samples) - - if is_training: - # Training-specific generation logic - batch_inputs = expanded_prompts[batch_start:batch_end] - batch_tokens = [[] for _ in range(batch_end - batch_start)] - prompt_caches = [cache.make_prompt_cache(model) for _ in range(batch_end - batch_start)] - - # Initial forward pass - for i, prompt in enumerate(batch_inputs): - logits = model(prompt[None], cache=prompt_caches[i])[:, -1] - logits_temp = logits / temperature - next_token = mx.random.categorical(logits_temp) - token = next_token.item() - batch_tokens[i].append(token) - del logits, logits_temp, next_token - - mx.eval([tokens[-1] for tokens in batch_tokens]) - mx.metal.clear_cache() - - active_indices = [i for i in range(len(batch_tokens)) if batch_tokens[i][-1] != tokenizer.eos_token_id] - - # Generate remaining tokens - for _ in range(max_tokens - 1): - if not active_indices: - break - - next_active = [] - for idx in active_indices: - current_input = mx.array([batch_tokens[idx][-1]]) - logits = model(current_input[None], cache=prompt_caches[idx])[:, -1] - logits_temp = logits / temperature - next_token = mx.random.categorical(logits_temp) - token = next_token.item() - batch_tokens[idx].append(token) - - # Check for end conditions - is_end = False - if len(batch_tokens[idx]) >= len(end_sequence): - test_sequence = batch_tokens[idx][-len(end_sequence):] - is_end = mx.array_equal(mx.array(test_sequence), end_sequence) - - if not (is_end or token == tokenizer.eos_token_id): - next_active.append(idx) - - del logits, logits_temp, next_token, current_input - - mx.eval([tokens[-1] for tokens in batch_tokens]) - mx.metal.clear_cache() - active_indices = next_active - - # Clean up caches - for pc in prompt_caches: - del pc - - # Process results - for tokens in batch_tokens: - if tokens: - # Truncate at end token if present - for i in range(len(tokens) - len(end_sequence) + 1): - if mx.array_equal( - mx.array(tokens[i:i+len(end_sequence)]), - end_sequence - ): - tokens = tokens[:i+len(end_sequence)] - break - - if tokens and tokens[-1] == tokenizer.eos_token_id: - tokens = tokens[:-1] - - if tokens: - results.append(mx.array(tokens)) - - del batch_inputs, batch_tokens, prompt_caches - mx.metal.clear_cache() - else: - # Non-training mode with batched processing - for idx in range(batch_start, batch_end): - current_tokens = [] - generator = generate_step( + batch_time = time.time() + print( + f"Starting batch {batch_start//batch_size + 1}/{(total_samples + batch_size - 1)//batch_size}: samples {batch_start}-{batch_end-1}" + ) + + # Custom sampler function that handles temperature + def temp_sampler(logits): + return mx.random.categorical(logits / temperature) + + # Batched processing + for idx in range(batch_start, batch_end): + sample_start_time = time.time() + current_tokens = [] + prompt_cache = cache.make_prompt_cache(model) + + # The generate_step function yields one token at a time + # We'll collect tokens until we hit max_tokens or a stopping condition + for i, (token, _) in enumerate( + generate_step( 
                    expanded_prompts[idx],
                    model,
-                    max_tokens=max_tokens,
-                    sampler=lambda x: mx.random.categorical(x / temperature)
+                    max_tokens=max_tokens,  # This is the maximum number of steps
+                    sampler=temp_sampler,
+                    prompt_cache=prompt_cache,
                 )
-                
-                for token, _ in generator:
-                    test_sequence = current_tokens + [token]
-                    if (len(test_sequence) >= len(end_sequence) and
-                        mx.array_equal(
-                            mx.array(test_sequence[-len(end_sequence):]),
-                            end_sequence
-                        )):
-                        current_tokens.append(token)
-                        break
-                    
-                    if token == tokenizer.eos_token_id:
-                        break
-                    current_tokens.append(token)
-                
-                if current_tokens:
-                    results.append(mx.array(current_tokens))
-                
+            ):
+                current_tokens.append(token)
+
+                # Check for end token
+                if len(current_tokens) >= len(end_sequence) and mx.array_equal(
+                    mx.array(current_tokens[-len(end_sequence) :]), end_sequence
+                ):
+                    break
+
+                # Check for EOS token
+                if token == tokenizer.eos_token_id:
+                    break
+
+                # Check if we've reached the maximum number of tokens
+                if i >= max_tokens - 1:
+                    break
+
+            if current_tokens:
+                results.append(mx.array(current_tokens))
+                total_tokens_generated += len(current_tokens)
+
+            sample_time = time.time() - sample_start_time
+            tokens_per_second = (
+                len(current_tokens) / sample_time if sample_time > 0 else 0
+            )
+            print(
+                f"  Sample {idx}: Generated {len(current_tokens)} tokens in {sample_time:.2f}s ({tokens_per_second:.2f} tokens/sec)"
+            )
+
+        batch_time = time.time() - batch_time
+        print(f"Batch completed in {batch_time:.2f}s")
         mx.metal.clear_cache()
-        
+
+        generation_time = time.time() - generation_start_time
+        avg_tokens_per_second = (
+            total_tokens_generated / generation_time if generation_time > 0 else 0
+        )
+
+        print(
+            f"Generation complete: {total_tokens_generated} tokens in {generation_time:.2f}s"
+        )
+        print(f"Average generation speed: {avg_tokens_per_second:.2f} tokens/sec")
+
         mx.eval(results)
         return results
 
     except Exception as e:
         print(f"Generation error: {str(e)}")
         return None
-    
-    finally:
-        # Don't restore training mode - let the caller handle it
-        pass
 
 
 def get_per_token_logps(model: nn.Module, inputs, lengths):
@@ -209,8 +285,7 @@ def get_per_token_logps(model: nn.Module, inputs, lengths):
         seq_targets = targets[i, :seq_len]
         log_probs = nn.log_softmax(seq_logits, axis=-1)
         token_log_probs = mx.take_along_axis(
-            log_probs,
-            seq_targets.reshape(seq_len, 1), axis=-1
+            log_probs, seq_targets.reshape(seq_len, 1), axis=-1
         ).squeeze(-1)
         per_token_logps.append(token_log_probs)
         mx.eval(logits)
@@ -223,14 +298,14 @@ def grpo_loss(
     tokenizer,
     batch,
     reward_funcs: Optional[List[RewardFunctions]] = None,
-    beta: float =0.1,
+    beta: float = 0.1,
     group_size: int = 4,
     epsilon: float = 1e-4,
     max_tokens: int = 64,
     temperature: float = 0.8,
     reward_weights: Optional[List[float]] = None,
     is_validation: bool = False,
-    batch_size: int = 1
+    batch_size: int = 1,
 ):
     prompt_tokens, _, prompt_text, answer_text = batch
     total_samples = len(prompt_tokens)
@@ -238,53 +313,48 @@ def grpo_loss(
     all_completions = []
     all_completion_texts = []
     batch_indices = []  # Keep track of which batch each completion belongs to
-    
+
+    # Store original training state
+    was_training = model.training
+
+    # Set model to eval mode for generation
+    model.eval()
+
     # Process in smaller batches
     for i in range(0, total_samples, batch_size):
         # Get actual batch size for this iteration (might be smaller for the last batch)
         current_batch_size = min(batch_size, total_samples - i)
-        batch_prompts = prompt_tokens[i:i+current_batch_size]
-        
+        batch_prompts = prompt_tokens[i : i + current_batch_size]
+
        # 
Pad sequences to the same length max_prompt_len = max(len(p) for p in batch_prompts) padded_prompts = [] - + for prompt in batch_prompts: padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt)) padded_prompts.append(prompt + padding) - + # Convert to tensor prompt_tensor = mx.array(padded_prompts) - + try: - if is_validation: - completions = generate_grpo( - model, - prompt_tensor, - max_tokens, - tokenizer, - group_size, - temperature=temperature, - batch_size=current_batch_size - ) - model.train() - else: - completions = generate_grpo( - model, - prompt_tensor, - max_tokens, - tokenizer, - group_size, - is_training=True, - temperature=temperature, - batch_size=current_batch_size - ) - + completions = generate_grpo( + model, + prompt_tensor, + max_tokens, + tokenizer, + group_size, + temperature=temperature, + batch_size=current_batch_size, + ) + if completions is not None: for j, completion_ids in enumerate(completions): # Calculate which prompt this completion belongs to prompt_idx = i + (j // group_size) - if prompt_idx < total_samples: # Make sure we don't go out of bounds + if ( + prompt_idx < total_samples + ): # Make sure we don't go out of bounds batch_indices.append(prompt_idx) completion_text = tokenizer.decode(completion_ids.tolist()) all_completions.append(completion_ids) @@ -293,65 +363,74 @@ def grpo_loss( except Exception as e: print(f"Generation error: {e}") continue - + + # Restore original training state if we're not in validation mode + if not is_validation and was_training: + model.train() + mx.metal.clear_cache() # If we didn't generate any completions, return early if not all_completions: - raise ValueError("No completions were generated. Please check your model and inputs.") + raise ValueError( + "No completions were generated. Please check your model and inputs." 
+ ) + # The rest of the function remains the same # Create expanded prompts and answers based on actual generated completions expanded_answers = [] expanded_prompts = [] - + # Group completions by their original prompt unique_prompt_indices = sorted(set(batch_indices)) grouped_completions = {idx: [] for idx in unique_prompt_indices} - + for i, completion_idx in enumerate(batch_indices): grouped_completions[completion_idx].append(i) - + # Rebuild completions in the correct order ordered_completions = [] ordered_completion_texts = [] ordered_batch_indices = [] - + for prompt_idx in unique_prompt_indices: completion_indices = grouped_completions[prompt_idx] for idx in completion_indices: ordered_completions.append(all_completions[idx]) ordered_completion_texts.append(all_completion_texts[idx]) ordered_batch_indices.append(prompt_idx) - + # Add corresponding prompt and answer expanded_prompts.append(prompt_text[prompt_idx]) expanded_answers.append(answer_text[prompt_idx]) - + all_completions = ordered_completions all_completion_texts = ordered_completion_texts batch_indices = ordered_batch_indices - + # Continue with the rest of the function max_length = max(ids.shape[0] for ids in all_completions) padded_completions = [] attention_masks = [] - + for completion_ids in all_completions: padding_length = max_length - completion_ids.shape[0] if padding_length > 0: padding = mx.zeros((padding_length,), dtype=completion_ids.dtype) padded_ids = mx.concatenate([completion_ids, padding]) - mask = mx.concatenate([mx.ones_like(completion_ids), mx.zeros_like(padding)]) + mask = mx.concatenate( + [mx.ones_like(completion_ids), mx.zeros_like(padding)] + ) else: padded_ids = completion_ids mask = mx.ones_like(completion_ids) padded_completions.append(padded_ids) attention_masks.append(mask) - + inputs = mx.stack(padded_completions) attention_mask = mx.stack(attention_masks) lengths = attention_mask.sum(axis=1) - + # Current policy probabilities token_log_probs = get_per_token_logps(model, inputs, lengths) mx.eval(token_log_probs) @@ -361,7 +440,7 @@ def grpo_loss( else: ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths) mx.eval(ref_token_log_probs) - + max_len = max(x.shape[0] for x in token_log_probs) padded_log_probs = [] padded_ref_log_probs = [] @@ -369,23 +448,25 @@ def grpo_loss( for i in range(len(token_log_probs)): seq_len = token_log_probs[i].shape[0] padding = mx.zeros((max_len - seq_len,)) - + padded_log_probs.append(mx.concatenate([token_log_probs[i], padding])) padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding])) - + token_log_probs = mx.stack(padded_log_probs) ref_token_log_probs = mx.stack(padded_ref_log_probs) - + # Create array to store rewards from each function all_func_rewards = [] # Collect rewards from each function separately for reward_func in reward_funcs: - func_rewards = mx.array(reward_func( - prompts=expanded_prompts, - completions=all_completion_texts, - answer=expanded_answers - )) + func_rewards = mx.array( + reward_func( + prompts=expanded_prompts, + completions=all_completion_texts, + answer=expanded_answers, + ) + ) all_func_rewards.append(func_rewards) # Stack rewards to shape (num_samples, num_funcs) @@ -401,18 +482,18 @@ def grpo_loss( reward_weights = mx.array(reward_weights, dtype=mx.float32) else: reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32) - + rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1) # Get number of unique prompts num_unique_prompts = len(unique_prompt_indices) - + # Reshape 
rewards based on actual groups rewards_by_prompt = [[] for _ in range(num_unique_prompts)] for i, prompt_idx in enumerate(batch_indices): prompt_position = unique_prompt_indices.index(prompt_idx) rewards_by_prompt[prompt_position].append(rewards[i]) - + # Calculate advantages within each group advantages = mx.zeros_like(rewards) for i, prompt_rewards in enumerate(rewards_by_prompt): @@ -420,33 +501,47 @@ def grpo_loss( prompt_rewards = mx.array(prompt_rewards) mean_reward = mx.mean(prompt_rewards) std_reward = mx.std(prompt_rewards) - + # Find indices for this prompt - indices = [j for j, idx in enumerate(batch_indices) if idx == unique_prompt_indices[i]] + indices = [ + j + for j, idx in enumerate(batch_indices) + if idx == unique_prompt_indices[i] + ] for j, idx in enumerate(indices): - advantages[idx] = (prompt_rewards[j] - mean_reward) / (std_reward + epsilon) + advantages[idx] = (prompt_rewards[j] - mean_reward) / ( + std_reward + epsilon + ) else: # If only one sample, advantage is 0 idx = batch_indices.index(unique_prompt_indices[i]) advantages[idx] = 0.0 - + # Compute KL divergence using Schulman's approximator - kl_div = mx.exp(ref_token_log_probs - token_log_probs) - (ref_token_log_probs - token_log_probs) - 1 - + kl_div = ( + mx.exp(ref_token_log_probs - token_log_probs) + - (ref_token_log_probs - token_log_probs) + - 1 + ) + # Create mask for valid tokens length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1) - + # Compute policy ratio - policy_ratio = mx.exp(mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs))) - + policy_ratio = mx.exp( + mx.array(token_log_probs - mx.stop_gradient(ref_token_log_probs)) + ) + # Compute per-token loss - per_token_loss = -((policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask) - + per_token_loss = -( + (policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask + ) + # Average over tokens sequence_sums = per_token_loss.sum(axis=1) sequence_lengths = length_mask.sum(axis=1) loss = (sequence_sums / sequence_lengths).mean() - + # Calculate mean KL divergence for metrics mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean() @@ -454,70 +549,82 @@ def grpo_loss( reward_metrics = {} for i, reward_func in enumerate(reward_funcs): func_name = reward_func.__name__ - func_rewards = mx.array(reward_func( - prompts=expanded_prompts, - completions=all_completion_texts, - answer=expanded_answers - )) - reward_metrics[f'{func_name}_mean'] = mx.mean(func_rewards) - reward_metrics[f'{func_name}_std'] = mx.std(func_rewards) + func_rewards = mx.array( + reward_func( + prompts=expanded_prompts, + completions=all_completion_texts, + answer=expanded_answers, + ) + ) + reward_metrics[f"{func_name}_mean"] = mx.mean(func_rewards) + reward_metrics[f"{func_name}_std"] = mx.std(func_rewards) + grouped_rewards_mean = mx.array( + [mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt] + ) + grouped_rewards_std = mx.array( + [ + mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) + for rewards in rewards_by_prompt + ] + ) - grouped_rewards_mean = mx.array([mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt]) - grouped_rewards_std = mx.array([mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.zeros(1) for rewards in rewards_by_prompt]) - metrics = { - 'total_rewards_mean': mx.mean(rewards), - 'total_rewards_std': mx.std(rewards), - 'grouped_rewards_mean': mx.mean(grouped_rewards_mean), - 'grouped_rewards_std': mx.mean(grouped_rewards_std), - 'kl': 
mean_kl,
-        **reward_metrics
+        "total_rewards_mean": mx.mean(rewards),
+        "total_rewards_std": mx.std(rewards),
+        "grouped_rewards_mean": mx.mean(grouped_rewards_mean),
+        "grouped_rewards_std": mx.mean(grouped_rewards_std),
+        "kl": mean_kl,
+        **reward_metrics,
     }
 
     if is_validation and all_completion_texts:
         print("\n=== Validation Sample Details ===")
-        
+
         # Print the input context (prompt)
         last_prompt_idx = batch_indices[-1] if batch_indices else 0
-        
+
         if last_prompt_idx < len(prompt_text):
             print(f"\n📋 Raw Prompt:\n{prompt_text[last_prompt_idx]}")
-            print("\n" + "="*10 + "\n")
-        
+            print("\n" + "=" * 10 + "\n")
+
         # Get the actual tokenized prompt that was fed to the model
         if last_prompt_idx < len(prompt_tokens):
             actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
             print(f"\n🔄 Model Input:\n{actual_prompt}")
-            print("\n" + "="*10 + "\n")
-        
+            print("\n" + "=" * 10 + "\n")
+
         print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
-        print("\n" + "="*10 + "\n")
-        
+        print("\n" + "=" * 10 + "\n")
+
         # Make sure we have a valid index for answer_text
         if last_prompt_idx < len(answer_text):
             print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
-            print("\n" + "="*10 + "\n")
-        
+            print("\n" + "=" * 10 + "\n")
+
         # Only try to extract if r1_extract_xml_answer is defined
-        if 'r1_extract_xml_answer' in globals():
-            print(f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}")
-            print("\n" + "="*35 + "\n")
-    
+        if "r1_extract_xml_answer" in globals():
+            print(
+                f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}"
+            )
+            print("\n" + "=" * 35 + "\n")
+
     mx.metal.clear_cache()
-    
+
     return loss, sequence_lengths.sum(), metrics
 
 
 def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
     if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
-        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")
+        raise ValueError(
+            "Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
+        )
 
     def length_key(i):
         return len(dataset[i][0]) + len(dataset[i][1])
-    
+
     idx = sorted(range(len(dataset)), key=length_key)
-    
+
     if len(dataset) < batch_size:
         raise ValueError(
             f"Dataset must have at least batch_size={batch_size} "
@@ -534,13 +641,14 @@ def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
 
     while True:
         indices = (
-            np.random.permutation(list(batch_index_generator())) if train
+            np.random.permutation(list(batch_index_generator()))
+            if train
             else batch_index_generator()
         )
-        
+
         for batch_idx in indices:
             current_batch = [dataset[j] for j in batch_idx]
-            
+
             prompts_tokens = [item[0] for item in current_batch]
             answers_tokens = [item[1] for item in current_batch]
             prompts_text = [item[2] for item in current_batch]
             answers_text = [item[3] for item in current_batch]
@@ -576,17 +684,17 @@ def evaluate_grpo(
         r1_int_reward_func,
         r1_strict_format_reward_func,
         r1_soft_format_reward_func,
-        r1_count_xml
+        r1_count_xml,
     ],
     loss_fn: callable = grpo_loss,
-    iterate_batches: callable = iterate_grpo_batches
+    iterate_batches: callable = iterate_grpo_batches,
 ):
     all_losses = 0
     ntokens = 0
     all_metrics = None
-    
+
     index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-    
+
     for _, batch in zip(
         index_iterator,
         iterate_batches(
@@ -606,27 +714,27 @@ def evaluate_grpo(
             ref_model=ref_model,
             temperature=temperature,
             max_tokens=max_tokens,
-            is_validation=True
+            is_validation=True,
         )
-        
+
         all_losses += losses * toks
         ntokens += toks
-        
+
         if all_metrics is None:
            all_metrics = {k: v * 
toks for k, v in metrics.items()} else: for k, v in metrics.items(): all_metrics[k] += v * toks - + mx.eval(all_losses, ntokens) - + all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu) ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu) all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()} - + avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()} avg_loss = (all_losses / ntokens).item() - + return avg_loss, ntokens, avg_metrics @@ -642,14 +750,16 @@ def train_grpo( r1_int_reward_func, r1_strict_format_reward_func, r1_soft_format_reward_func, - r1_count_xml + r1_count_xml, ], args: GRPOTrainingArgs = GRPOTrainingArgs(), loss_fn: callable = grpo_loss, iterate_batches: callable = iterate_grpo_batches, training_callback: TrainingCallback = None, ): - print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}") + print( + f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}" + ) world = mx.distributed.init() world_size = world.size() rank = world.rank() @@ -663,7 +773,7 @@ def train_grpo( def step(batch): (loss, toks, metrics), grad = loss_value_and_grad( - model, + model, tokenizer=tokenizer, batch=batch, reward_funcs=reward_funcs, @@ -672,7 +782,7 @@ def train_grpo( epsilon=args.epsilon, ref_model=ref_model, max_tokens=args.max_completion_length, - temperature=args.temperature + temperature=args.temperature, ) grad = average_gradients(grad) @@ -680,7 +790,7 @@ def train_grpo( optimizer.update(model, grad) return loss, toks, metrics - + loss_value_and_grad = nn.value_and_grad(model, loss_fn) losses = 0 @@ -688,16 +798,16 @@ def train_grpo( steps = 0 trained_tokens = 0 accumulated_metrics = { - 'total_rewards_mean': 0, - 'total_rewards_std': 0, - 'grouped_rewards_mean': 0, - 'grouped_rewards_std': 0, - 'kl': 0 + "total_rewards_mean": 0, + "total_rewards_std": 0, + "grouped_rewards_mean": 0, + "grouped_rewards_std": 0, + "kl": 0, } for reward_func in reward_funcs: func_name = reward_func.__name__ - accumulated_metrics[f'{func_name}_mean'] = 0 - accumulated_metrics[f'{func_name}_std'] = 0 + accumulated_metrics[f"{func_name}_mean"] = 0 + accumulated_metrics[f"{func_name}_std"] = 0 start = time.perf_counter() for it, batch in zip( @@ -738,26 +848,27 @@ def train_grpo( f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, " f"Val kl {val_metrics['kl']:.3f}" ) - + for i, reward_func in enumerate(reward_funcs): val_metrics_str += ( f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, " f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}" ) - + print( - f"Iter {it}: {val_metrics_str}, " - f"Val took {val_time:.3f}s", + f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s", flush=True, ) if training_callback is not None: - training_callback.on_val_loss_report({ - "iteration": it, - "val_loss": val_loss, - **{f"val_{k}": v for k, v in val_metrics.items()}, - "val_time": val_time, - }) + training_callback.on_val_loss_report( + { + "iteration": it, + "val_loss": val_loss, + **{f"val_{k}": v for k, v in val_metrics.items()}, + "val_time": val_time, + } + ) start = time.perf_counter() @@ -776,7 +887,9 @@ def train_grpo( train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item() train_loss /= steps * mx.distributed.init().size() - avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()} + avg_metrics = { + k: v / (steps * world_size) for k, v in 
accumulated_metrics.items() + } n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item() learning_rate = optimizer.learning_rate.item() it_sec = args.steps_per_report / (stop - start) @@ -793,7 +906,7 @@ def train_grpo( f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, " f"KL {avg_metrics['kl']:.3f}" ) - + for i, reward_func in enumerate(reward_funcs): func_name = reward_func.__name__ train_metrics_str += ( @@ -811,16 +924,18 @@ def train_grpo( ) if training_callback is not None: - training_callback.on_train_loss_report({ - "iteration": it, - "train_loss": train_loss, - **{f"train_{k}": v for k, v in avg_metrics.items()}, - "learning_rate": learning_rate, - "iterations_per_second": it_sec, - "tokens_per_second": tokens_sec, - "trained_tokens": trained_tokens, - "peak_memory": peak_mem, - }) + training_callback.on_train_loss_report( + { + "iteration": it, + "train_loss": train_loss, + **{f"train_{k}": v for k, v in avg_metrics.items()}, + "learning_rate": learning_rate, + "iterations_per_second": it_sec, + "tokens_per_second": tokens_sec, + "trained_tokens": trained_tokens, + "peak_memory": peak_mem, + } + ) losses = 0 n_tokens = 0 @@ -841,4 +956,4 @@ def train_grpo( adapter_weights = dict(tree_flatten(model.trainable_parameters())) mx.save_safetensors(str(args.adapter_file), adapter_weights) - print(f"Saved final weights to {args.adapter_file}.") \ No newline at end of file + print(f"Saved final weights to {args.adapter_file}.")
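
Reviewer note: the two pieces of math at the heart of grpo_loss above — the per-group advantage normalization and the Schulman KL approximator the patch's comments reference — can be sanity-checked in isolation. Below is a minimal sketch, assuming only mlx is installed; the toy reward and log-probability values are illustrative and not taken from the patch:

    import mlx.core as mx

    # One prompt with group_size=4 completions; epsilon matches the
    # GRPOTrainingArgs default used for numerical stability.
    rewards = mx.array([1.0, 0.5, 2.0, 0.0])
    epsilon = 1e-4

    # Advantage = (reward - group mean) / (group std + epsilon),
    # exactly the normalization applied per prompt group in grpo_loss.
    advantages = (rewards - mx.mean(rewards)) / (mx.std(rewards) + epsilon)

    # Schulman's k3 estimator: exp(ref - cur) - (ref - cur) - 1,
    # which is elementwise non-negative and zero when the policies agree.
    ref_logps = mx.array([-1.2, -0.8, -2.0, -1.5])
    cur_logps = mx.array([-1.0, -0.9, -1.8, -1.6])
    diff = ref_logps - cur_logps
    kl_div = mx.exp(diff) - diff - 1

    print(advantages)  # zero-mean within the group
    print(kl_div)      # small positive values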