# Copyright © 2024 Apple Inc.

import re
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Optional

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten

from .trainer import (
    grad_checkpoint,
    TrainingArgs,
    TrainingCallback,
    average_gradients,
    iterate_batches,
)

from mlx_lm.utils import generate_step


# NOTE(review): the tag literals below were reconstructed. In the reviewed
# source the separators and regexes had lost their markup (e.g. the answer
# extractor split on an empty string, which always raises ValueError). The
# `<think>` / `<answer>` names follow the R1-style layout implied by the
# function names and the surviving "\n" skeleton -- confirm against the
# prompt template actually used for training.
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"
_ANSWER_OPEN = "<answer>"
_ANSWER_CLOSE = "</answer>"

# NOTE: neither pattern uses re.DOTALL, so `.` does not cross newlines.
_SOFT_FORMAT_PATTERN = r"<think>.*?</think>\s*<answer>.*?</answer>"
_STRICT_FORMAT_PATTERN = r"^<think>\n.*?\n</think>\n<answer>\n.*?\n</answer>\n$"


@dataclass
class GRPOTrainingArgs(TrainingArgs):
    """Training arguments specific to GRPO, on top of ``TrainingArgs``."""

    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    beta: float = field(
        default=0.1,
        metadata={"help": "KL penalty coefficient."},
    )
    epsilon: float = field(
        default=1e-4,
        metadata={"help": "The Epsilon for numerical stability."},
    )
    reference_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        },
    )


def generate_for_grpo(model, prompt, max_tokens, tokenizer, temperature=1.0):
    """Generate a completion for ``prompt`` one token at a time.

    Args:
        model: Callable mapping a ``(1, seq)`` token array to logits indexed
            as ``logits[0, -1]`` for the next-token distribution.
        prompt: 1-D token array, or a 2-D array of which row 0 is used.
        max_tokens: Maximum number of tokens to generate.
        tokenizer: Object exposing ``eos_token_id``; generation stops after
            that token is produced.
        temperature: Sampling temperature. Values ``<= 0`` select the
            arg-max token instead of sampling.

    Returns:
        1-D array holding the prompt tokens followed by the generated tokens
        (just the prompt tokens when nothing was generated).
    """
    if len(prompt.shape) == 1:
        prompt = prompt[None, :]
    prompt_tokens = prompt[0]

    sequence = prompt_tokens
    num_generated = 0
    # NOTE: every step re-runs the model on the whole sequence (no KV cache).
    for _ in range(max_tokens):
        logits = model(sequence[None, :])
        token_logits = logits[0, -1]

        if temperature > 0:
            # mx.random.categorical expects unnormalised log-probabilities
            # (logits). Feeding it softmax output would sample from a
            # nearly uniform distribution.
            next_token = mx.random.categorical(token_logits[None, :] / temperature)[0]
        else:
            next_token = mx.argmax(token_logits, axis=-1)

        # Keep the sequence dtype stable instead of promoting on concatenate.
        next_token = next_token.astype(sequence.dtype)
        # Force evaluation so the EOS check below sees a concrete value.
        mx.eval(next_token)

        sequence = mx.concatenate([sequence, next_token[None]])
        num_generated += 1

        if next_token.item() == tokenizer.eos_token_id:
            break

    if num_generated == 0:
        return prompt_tokens

    mx.eval(sequence)
    return sequence


def r1_extract_xml_answer(text: str) -> str:
    """Extract the text between the last ``<answer>`` and the next ``</answer>``.

    When the opening tag is missing the whole text is used as the starting
    point; when the closing tag is missing everything after the opening tag
    is returned. The result is stripped of surrounding whitespace.

    Returns an empty string (after printing a message) if ``text`` does not
    support string splitting, e.g. when it is ``None``.
    """
    try:
        answer = text.split(_ANSWER_OPEN)[-1]
        answer = answer.split(_ANSWER_CLOSE)[0]
        return answer.strip()
    except (AttributeError, TypeError):
        print("[extract_xml_answer] Failed to extract answer from: ", text)
        return ""


def r1_accuracy_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions whose extracted answer equals the expected answer.

    Args:
        prompts: List of input prompts (unused).
        completions: List of completion strings.
        answer: List of expected answers, aligned with ``completions``.
        **kwargs: Additional arguments (unused).

    Returns:
        list[float]: 2.0 for an exact match, 0.0 otherwise. The result is as
        long as the shorter of ``completions`` and ``answer``.
    """
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]


def r1_int_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions whose extracted answer consists only of digits.

    Args:
        prompts: List of input prompts (unused).
        completions: List of completion strings.
        answer: List of expected answers (unused).
        **kwargs: Additional arguments (unused).

    Returns:
        list[float]: 0.5 when ``str.isdigit()`` holds for the extracted
        answer, 0.0 otherwise.
    """
    extracted_responses = [r1_extract_xml_answer(r) for r in completions]
    return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]


def r1_soft_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions that start with a loosely formatted think/answer pair.

    Returns:
        list[float]: 0.5 when the soft pattern matches at the start of the
        completion, 0.0 otherwise.
    """
    matches = [re.match(_SOFT_FORMAT_PATTERN, r) for r in completions]
    return [0.5 if match else 0.0 for match in matches]


def r1_strict_format_reward_func(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Reward completions that follow the strict, newline-delimited layout.

    Args:
        prompts: List of input prompts (unused).
        completions: List of completion strings.
        answer: List of expected answers (unused).
        **kwargs: Additional arguments (unused).

    Returns:
        list[float]: 0.5 when the strict pattern matches, 0.0 otherwise.
    """
    matches = [re.match(_STRICT_FORMAT_PATTERN, r) for r in completions]
    return [0.5 if match else 0.0 for match in matches]


def _count_xml_score(text: str) -> float:
    """Score one completion on tag presence, penalising trailing text."""
    score = 0.0
    if text.count(f"{_THINK_OPEN}\n") == 1:
        score += 0.125
    if text.count(f"\n{_THINK_CLOSE}\n") == 1:
        score += 0.125
    if text.count(f"\n{_ANSWER_OPEN}\n") == 1:
        score += 0.125
        # Penalise anything after the closing answer tag and its newline.
        score -= len(text.split(f"\n{_ANSWER_CLOSE}\n")[-1]) * 0.001
    if text.count(f"\n{_ANSWER_CLOSE}") == 1:
        score += 0.125
        # The "- 1" discounts one character after the closing tag.
        score -= (len(text.split(f"\n{_ANSWER_CLOSE}")[-1]) - 1) * 0.001
    return score


def r1_count_xml(prompts: list, completions: list, answer: list, **kwargs) -> list[float]:
    """Score completions on the presence and placement of the XML tags.

    Args:
        prompts: List of input prompts (unused).
        completions: List of completion strings to evaluate.
        answer: List of expected answers (unused).
        **kwargs: Additional arguments (unused).

    Returns:
        list[float]: One score per completion; each of the four expected
        tag occurrences is worth 0.125, minus small penalties for text that
        follows the closing answer tag.
    """
    return [_count_xml_score(text) for text in completions]


def _default_reward_funcs():
    """Return a fresh list with the standard R1 reward functions."""
    return [
        r1_accuracy_reward_func,
        r1_int_reward_func,
        r1_strict_format_reward_func,
        r1_soft_format_reward_func,
        r1_count_xml,
    ]


def _resolve_reward_funcs(reward_funcs):
    """Substitute the defaults for ``None`` and reject an empty collection."""
    if reward_funcs is None:
        return _default_reward_funcs()
    reward_funcs = list(reward_funcs)
    if not reward_funcs:
        raise ValueError("At least one reward function is required.")
    return reward_funcs


def _gather_token_log_probs(model, inputs, targets):
    """Return the log-probability ``model`` assigns to each target token."""
    logits = model(inputs).astype(mx.float32)
    log_probs = nn.log_softmax(logits[:, :-1, :], axis=-1)
    return mx.take_along_axis(
        log_probs, targets.reshape(*targets.shape, 1), axis=-1
    ).squeeze(-1)


def grpo_loss(
    model,
    tokenizer,
    batch,
    reward_funcs=None,
    beta=0.1,
    group_size=4,
    epsilon=1e-4,
    ref_model=None,
    max_tokens=128,
    temperature=1.0,
):
    """Compute the GRPO loss for one batch of prompts.

    For every prompt ``group_size`` completions are generated, scored with
    ``reward_funcs`` and turned into group-normalised advantages. The loss
    is the advantage-weighted policy term plus ``beta`` times a per-token
    KL estimate against the reference model.

    Args:
        model: Policy model being trained.
        tokenizer: Tokenizer providing ``decode`` and ``eos_token_id``.
        batch: Tuple ``(prompt_tokens, answer_tokens, prompt_text,
            answer_text)`` of equally long lists.
        reward_funcs: Reward callables taking ``prompts``, ``completions``
            and ``answer`` keyword arguments and returning one float per
            completion. ``None`` selects the standard R1 reward set.
        beta: KL penalty coefficient.
        group_size: Number of completions generated per prompt.
        epsilon: Added to the group standard deviation for stability.
        ref_model: Reference model; when ``None`` the policy itself is the
            reference, which makes the KL term identically zero.
        max_tokens: Maximum number of generated tokens per completion.
        temperature: Sampling temperature used for generation.

    Returns:
        Tuple ``(loss, ntokens, metrics)`` where ``ntokens`` is the number
        of unmasked target positions and ``metrics`` maps metric names to
        scalar arrays.

    Raises:
        ValueError: If ``reward_funcs`` is an empty collection.
    """
    reward_funcs = _resolve_reward_funcs(reward_funcs)

    prompt_tokens, answer_tokens, prompt_text, answer_text = batch
    batch_size = len(prompt_tokens)

    # Generate group_size completions per prompt, kept in prompt-major order.
    all_completions = []
    all_completion_texts = []
    for prompt in prompt_tokens:
        prompt_tensor = mx.array(prompt)
        prompt_length = prompt_tensor.shape[0]
        for _ in range(group_size):
            # Best effort: a failed generation falls back to the bare prompt
            # so the batch keeps its (batch_size * group_size) layout.
            try:
                completion_ids = generate_for_grpo(
                    model,
                    prompt_tensor,
                    max_tokens,
                    tokenizer=tokenizer,
                    temperature=temperature,
                )
                if completion_ids is None:
                    raise ValueError("generate_for_grpo returned None")
                # Rewards score only the generated part; including the
                # prompt would make the "^"-anchored format rewards
                # unreachable.
                completion_text = tokenizer.decode(
                    completion_ids[prompt_length:].tolist()
                )
            except Exception as e:
                print(f"Error in completion generation: {str(e)}")
                completion_ids = prompt_tensor
                completion_text = ""
            all_completions.append(completion_ids)
            all_completion_texts.append(completion_text)

    # Repeat prompts/answers so they line up one-to-one with completions.
    expanded_prompts = [p for p in prompt_text for _ in range(group_size)]
    expanded_answers = [a for a in answer_text for _ in range(group_size)]

    # Right-pad every sequence with zeros to the longest one.
    max_length = max(ids.shape[0] for ids in all_completions)
    lengths = mx.array([ids.shape[0] for ids in all_completions])
    padded_completions = []
    for completion_ids in all_completions:
        padding_length = max_length - completion_ids.shape[0]
        if padding_length > 0:
            padding = mx.zeros((padding_length,), dtype=completion_ids.dtype)
            completion_ids = mx.concatenate([completion_ids, padding])
        padded_completions.append(completion_ids)
    inputs = mx.stack(padded_completions)

    # Next-token targets and their log-probabilities under the policy.
    targets = inputs[:, 1:]
    token_log_probs = _gather_token_log_probs(model, inputs, targets)

    # Reference log-probabilities never receive gradients. Without a
    # separate reference model the policy's own values are reused rather
    # than paying for a second forward pass.
    if ref_model is not None:
        ref_token_log_probs = mx.stop_gradient(
            _gather_token_log_probs(ref_model, inputs, targets)
        )
    else:
        ref_token_log_probs = mx.stop_gradient(token_log_probs)

    # Per-token KL estimate: exp(d) - d - 1 with d = ref - policy.
    log_ratio = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(log_ratio) - log_ratio - 1

    # Evaluate every reward function once; reused for the metrics below.
    per_func_rewards = [
        mx.array(
            reward_func(
                prompts=expanded_prompts,
                completions=all_completion_texts,
                answer=expanded_answers,
            )
        )
        for reward_func in reward_funcs
    ]
    rewards = mx.zeros((len(all_completions),))
    for func_rewards in per_func_rewards:
        rewards = rewards + func_rewards
    if len(reward_funcs) > 1:
        rewards = rewards / len(reward_funcs)

    # Group-wise normalisation turns rewards into advantages.
    grouped_rewards = rewards.reshape(batch_size, group_size)
    group_mean = mx.mean(grouped_rewards, axis=1, keepdims=True)
    group_std = mx.std(grouped_rewards, axis=1, keepdims=True)
    advantages = ((grouped_rewards - group_mean) / (group_std + epsilon)).reshape(-1)

    # Mask for the shifted sequence: position j predicts token j + 1.
    # NOTE(review): the mask keeps prompt positions as well as generated
    # ones, so the loss is averaged over both -- confirm this is intended.
    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)

    # exp(x - stop_gradient(x)) equals 1 in value but carries d(log p).
    policy_term = mx.exp(
        token_log_probs - mx.stop_gradient(token_log_probs)
    ) * advantages.reshape(-1, 1)
    per_token_loss = -(policy_term - beta * kl_div)

    # Per-sequence mean; guard against sequences with no target positions.
    sequence_lengths = length_mask.sum(axis=1)
    safe_lengths = mx.maximum(sequence_lengths, 1)
    loss = ((per_token_loss * length_mask).sum(axis=1) / safe_lengths).mean()
    mean_kl = ((kl_div * length_mask).sum(axis=1) / safe_lengths).mean()

    metrics = {
        'total_rewards_mean': mx.mean(rewards),
        'total_rewards_std': mx.std(rewards),
        'grouped_rewards_mean': mx.mean(grouped_rewards),
        'grouped_rewards_std': mx.std(grouped_rewards),
        'kl': mean_kl,
    }
    for i, func_rewards in enumerate(per_func_rewards):
        metrics[f'reward_func_{i}_mean'] = mx.mean(func_rewards)
        metrics[f'reward_func_{i}_std'] = mx.std(func_rewards)

    return loss, sequence_lengths.sum(), metrics


def iterate_grpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """Create batches from dataset entries for GRPO training.

    Args:
        dataset: List of (prompt_tokens, answer_tokens, prompt_str,
            answer_str) tuples.
        tokenizer: Tokenizer (unused here; kept for a uniform signature).
        batch_size: Global batch size; must be divisible by the number of
            distributed workers.
        max_seq_length: Prompts longer than this are truncated.
        train: When True, batches are shuffled and iteration never ends;
            otherwise the dataset is traversed once in length order.

    Yields:
        Tuple (prompts_tokens, answers_tokens, prompts_text, answers_text)
        of lists for the current worker's share of the batch.

    Raises:
        ValueError: If the dataset is empty or malformed, smaller than
            ``batch_size``, or ``batch_size`` is not divisible by the number
            of workers.
    """
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError("Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples")

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    # Sort by combined length of prompt + answer tokens.
    idx = sorted(
        range(len(dataset)),
        key=lambda i: len(dataset[i][0]) + len(dataset[i][1]),
    )

    # In distributed mode every worker takes a strided slice of each batch,
    # offset by its rank so that workers see different samples.
    world = mx.distributed.init()
    step = world.size()
    offset = world.rank()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    batch_idx = [
        idx[i + offset : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        # Shuffle batch order when training.
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))

        for i in indices:
            current_batch = [dataset[j] for j in batch_idx[i]]

            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )
                # Perform the truncation the warning announces.
                prompts_tokens = [p[:max_seq_length] for p in prompts_tokens]

            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break


def evaluate_grpo(
    model,
    ref_model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length,
    reward_funcs=None,
    loss: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
):
    """Evaluate a model with the GRPO loss.

    Args:
        model: Policy model to evaluate.
        ref_model: Reference model forwarded to ``loss``.
        dataset: Dataset accepted by ``iterate_batches``.
        tokenizer: Tokenizer forwarded to the iterator and the loss.
        batch_size: Batch size for evaluation.
        num_batches: Number of batches to evaluate; ``-1`` means all.
        beta: KL penalty coefficient.
        epsilon: Stability constant for advantage normalisation.
        group_size: Completions generated per prompt.
        max_seq_length: Maximum prompt length.
        reward_funcs: Reward callables; ``None`` selects the standard R1 set.
        loss: Loss callable with the ``grpo_loss`` interface.
        iterate_batches: Batch iterator factory.

    Returns:
        tuple: (token-weighted average loss, number of tokens,
        token-weighted average metrics).

    Raises:
        ValueError: If no batch was evaluated or ``reward_funcs`` is empty.
    """
    reward_funcs = _resolve_reward_funcs(reward_funcs)

    all_losses = 0
    ntokens = 0
    all_metrics = None

    # iter(int, 1) never terminates (int() is 0, never 1): "all batches".
    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks, metrics = loss(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model,
        )

        # Weight by token count so the final averages are per token.
        all_losses = all_losses + losses * toks
        ntokens = ntokens + toks
        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] = all_metrics[k] + v * toks

        # Evaluate per batch so the lazy graph does not grow across batches.
        mx.eval(all_losses, ntokens, all_metrics)

    if all_metrics is None:
        raise ValueError(
            "No batches were evaluated; check num_batches and the dataset size."
        )

    # Aggregate across distributed workers.
    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    all_metrics = {
        k: mx.distributed.all_sum(v, stream=mx.cpu) for k, v in all_metrics.items()
    }

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, ntokens, avg_metrics


def _grpo_metric_keys(num_reward_funcs):
    """List the metric names ``grpo_loss`` reports for the given rewards."""
    keys = [
        'total_rewards_mean',
        'total_rewards_std',
        'grouped_rewards_mean',
        'grouped_rewards_std',
        'kl',
    ]
    for i in range(num_reward_funcs):
        keys.append(f'reward_func_{i}_mean')
        keys.append(f'reward_func_{i}_std')
    return keys


def train_grpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs=None,
    args: Optional[GRPOTrainingArgs] = None,
    loss: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: TrainingCallback = None,
):
    """Train ``model`` with GRPO and save its trainable (adapter) weights.

    Args:
        model: Policy model to train.
        ref_model: Reference model for the KL penalty (may be ``None``).
        tokenizer: Tokenizer used for generation and decoding.
        optimizer: Optimizer exposing ``update``, ``state`` and
            ``learning_rate``.
        train_dataset: Training dataset accepted by ``iterate_batches``.
        val_dataset: Validation dataset accepted by ``iterate_batches``.
        reward_funcs: Reward callables; ``None`` selects the standard R1 set.
        args: Training arguments; ``None`` uses ``GRPOTrainingArgs()``.
        loss: Loss callable with the ``grpo_loss`` interface.
        iterate_batches: Batch iterator factory.
        training_callback: Optional callback receiving validation and
            training reports.

    Raises:
        ValueError: If ``reward_funcs`` is an empty collection.
    """
    # Resolve defaults per call instead of sharing objects created at
    # function-definition time.
    reward_funcs = _resolve_reward_funcs(reward_funcs)
    if args is None:
        args = GRPOTrainingArgs()

    print(f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}")
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]
    loss_value_and_grad = nn.value_and_grad(model, loss)

    def step(batch):
        # Forward and backward pass
        (lvalue, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
        )

        # All reduce the gradients if running in distributed mode
        grad = average_gradients(grad)

        # Model update
        optimizer.update(model, grad)

        return lvalue, toks, metrics

    def fresh_metrics():
        # One distinct array per key; keys match what grpo_loss returns.
        return {k: mx.array(0.0) for k in _grpo_metric_keys(len(reward_funcs))}

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = fresh_metrics()

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        # Report validation loss if needed, the first validation loss
        # is always measured before any training.
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss=loss,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                beta=args.beta,
                epsilon=args.epsilon,
                iterate_batches=iterate_batches,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.8f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )

                # Add reward function specific metrics
                for i in range(len(reward_funcs)):
                    val_metrics_str += (
                        f", Val reward_func_{i}_mean {val_metrics[f'reward_func_{i}_mean']:.3f}, "
                        f"Val reward_func_{i}_std {val_metrics[f'reward_func_{i}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {val_metrics_str}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            start = time.perf_counter()

        # Named lvalue so the `loss` callable is not rebound to an array;
        # later validation passes still need to call it.
        lvalue, toks, metrics = step(batch)
        losses = losses + lvalue
        n_tokens = n_tokens + toks
        steps += 1

        for k, v in metrics.items():
            accumulated_metrics[k] = accumulated_metrics.get(k, mx.array(0.0)) + v

        mx.eval(state, losses, n_tokens, accumulated_metrics)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * world_size
            # Sum across workers before averaging, and convert to floats so
            # the values can be formatted and handed to the callback.
            avg_metrics = {
                k: mx.distributed.all_sum(v, stream=mx.cpu).item() / (steps * world_size)
                for k, v in accumulated_metrics.items()
            }
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.8f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )

                # Add reward function specific metrics
                for i in range(len(reward_funcs)):
                    train_metrics_str += (
                        f", Reward func {i} mean {avg_metrics[f'reward_func_{i}_mean']:.3f}, "
                        f"Reward func {i} std {avg_metrics[f'reward_func_{i}_std']:.3f}"
                    )

                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })

            # Start a new reporting window, metrics included.
            losses = 0
            n_tokens = 0
            steps = 0
            accumulated_metrics = fresh_metrics()
            start = time.perf_counter()

        # Save adapter weights
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")