# Copyright © 2024 Apple Inc.

import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten

from ..models import cache
from ..utils import generation_stream
from .grpo_reward_functions import (
    RewardFunctions,
    r1_accuracy_reward_func,
    r1_count_xml,
    r1_extract_xml_answer,
    r1_int_reward_func,
    r1_soft_format_reward_func,
    r1_strict_format_reward_func,
)
from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint

# Reward functions used by grpo_loss / evaluate_grpo / train_grpo when the
# caller does not pass ``reward_funcs`` explicitly.
_DEFAULT_REWARD_FUNCS = (
    r1_accuracy_reward_func,
    r1_int_reward_func,
    r1_strict_format_reward_func,
    r1_soft_format_reward_func,
    r1_count_xml,
)


@dataclass
class GRPOTrainingArgs(TrainingArgs):
    """GRPO-specific hyper-parameters layered on top of ``TrainingArgs``."""

    group_size: int = field(
        default=4,
        metadata={"help": "Number of responses per prompt."},
    )
    beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."})
    epsilon: float = field(
        default=1e-4, metadata={"help": "The Epsilon for numerical stability."}
    )
    max_completion_length: int = field(
        default=512,
        metadata={"help": "Maximum number of tokens generated per completion."},
    )
    reference_model_path: Optional[str] = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        },
    )
    temperature: float = field(
        default=1.0,
        metadata={
            "help": "Temperature for sampling. The higher the temperature, the more random the completions."
        },
    )
    reward_weights: Optional[List[float]] = field(
        default=None,
        metadata={
            "help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
        },
    )


def get_per_token_logps(model: nn.Module, inputs, lengths, prompt_lengths=None):
    """Return per-token log-probabilities of ``inputs`` under ``model``.

    Args:
        model: Called as ``model(inputs)``; its output is indexed as
            ``(batch, position, vocab)`` logits.
        inputs: Right-padded token ids, one row per sequence.
        lengths: Number of real (non-padding) tokens in each row.
        prompt_lengths: Optional number of leading prompt tokens per row.
            When given, only tokens *after* the prompt are scored (each one
            conditioned on the full prompt). When ``None``, every token except
            the first is scored (original behaviour).

    Returns:
        A list with one 1-D float32 array per row, holding
        ``log p(inputs[i, t] | inputs[i, :t])`` for each scored position ``t``.
    """
    logits = model(inputs)
    # Shifted logits at position t predict the token at position t + 1.
    logits = logits[:, :-1, :]
    targets = inputs[:, 1:]

    per_token_logps = []
    for i in range(logits.shape[0]):
        start = 0 if prompt_lengths is None else max(int(prompt_lengths[i]) - 1, 0)
        end = max(int(lengths[i]) - 1, start)
        # Log-softmax in float32: a float16 round-trip loses precision in the
        # log-prob differences used by the ratio/KL terms and can overflow.
        seq_logits = logits[i, start:end].astype(mx.float32)
        seq_targets = targets[i, start:end]
        log_probs = nn.log_softmax(seq_logits, axis=-1)
        token_log_probs = mx.take_along_axis(
            log_probs, seq_targets[:, None], axis=-1
        ).squeeze(-1)
        per_token_logps.append(token_log_probs)

    mx.eval(logits)
    return per_token_logps


def generate_step(
    prompt: mx.array,
    model: nn.Module,
    max_tokens: int = 256,
    sampler: Optional[Callable] = None,
    logits_processors: Optional[List[Callable]] = None,
    max_kv_size: Optional[int] = None,
    prompt_cache: Optional[Any] = None,
) -> Generator[Tuple[int, mx.array], None, None]:
    """Sample up to ``max_tokens`` tokens that continue ``prompt``.

    Args:
        prompt: Token ids of the prompt (``prompt[None]`` is fed to the model).
        model: Called as ``model(tokens[None], cache=prompt_cache)``.
        max_tokens: Maximum number of tokens to yield.
        sampler: Maps the ``(1, vocab)`` log-probabilities to the next token.
            Defaults to greedy (argmax) decoding.
        logits_processors: Optional callables ``processor(tokens, logits)``
            applied in order before sampling; ``tokens`` holds every token fed
            to the model so far.
        max_kv_size: Forwarded to ``cache.make_prompt_cache`` when
            ``prompt_cache`` is not given.
        prompt_cache: Optional pre-built cache; a fresh one is made otherwise.

    Yields:
        ``(token_id, logprobs)``: the sampled token as a Python int and the
        1-D log-probability vector it was sampled from.
    """
    if sampler is None:

        def sampler(logprobs):
            return mx.argmax(logprobs, axis=-1)

    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(model, max_kv_size=max_kv_size)

    tokens = None

    def _step(y):
        nonlocal tokens
        with mx.stream(generation_stream):
            logits = model(y[None], cache=prompt_cache)
            logits = logits[:, -1, :]
            if logits_processors:
                tokens = mx.concat([tokens, y]) if tokens is not None else y
                for processor in logits_processors:
                    logits = processor(tokens, logits)
            logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
            next_token = sampler(logprobs)
            return mx.stop_gradient(next_token), mx.stop_gradient(logprobs.squeeze(0))

    try:
        y, logprobs = _step(prompt)
        mx.eval(y, logprobs)
        for n in range(max_tokens):
            yield y.item(), logprobs
            # No forward pass is needed after the final token.
            if n + 1 == max_tokens:
                break
            y, logprobs = _step(y)
            mx.eval(y, logprobs)
            if (n + 1) % 32 == 0:
                mx.metal.clear_cache()
    finally:
        mx.metal.clear_cache()


def generate_grpo(
    model: nn.Module,
    tokenizer,
    prompt_tokens,
    max_tokens: int,
    group_size: int,
    temperature: float,
    batch_size: int,
    end_token: str = "",
):
    """Sample ``group_size`` completions for every prompt in ``prompt_tokens``.

    Each completion is generated independently from its own, unpadded prompt.
    Generation stops at ``tokenizer.eos_token_id`` (not included), at the end of
    ``end_token``'s token sequence (included), or after ``max_tokens`` tokens.
    Empty prompts are skipped, and so are empty completions.

    Returns:
        ``(completions, completion_texts, batch_indices)``: ``completions[k]``
        holds the generated token ids as an ``mx.array``, ``completion_texts[k]``
        holds its decoded text, and ``batch_indices[k]`` is the index into
        ``prompt_tokens`` of the prompt it was generated from.
    """
    # NOTE(review): tokenizer.encode may prepend special tokens (e.g. BOS) for
    # some tokenizers, in which case the end sequence can never match —
    # confirm for the tokenizers in use.
    end_sequence = list(tokenizer.encode(end_token)) if end_token else []
    end_len = len(end_sequence)

    if temperature > 0:

        def sampler(logprobs):
            return mx.random.categorical(logprobs / temperature)

    else:
        # temperature == 0 would divide by zero; treat it as greedy decoding.

        def sampler(logprobs):
            return mx.argmax(logprobs, axis=-1)

    all_completions = []
    all_completion_texts = []
    batch_indices = []

    try:
        for batch_start in range(0, len(prompt_tokens), batch_size):
            batch_prompts = prompt_tokens[batch_start : batch_start + batch_size]
            for offset, prompt in enumerate(batch_prompts):
                if len(prompt) == 0:
                    continue
                prompt_idx = batch_start + offset
                prompt_array = mx.array(prompt)
                for _ in range(group_size):
                    completion = []
                    for token, _logprobs in generate_step(
                        prompt_array,
                        model,
                        max_tokens=max_tokens,
                        sampler=sampler,
                        prompt_cache=cache.make_prompt_cache(model),
                    ):
                        if token == tokenizer.eos_token_id:
                            break
                        completion.append(token)
                        if end_len and completion[-end_len:] == end_sequence:
                            break
                    if not completion:
                        continue
                    # The prompt index is recorded per completion, so skipped
                    # (empty) completions cannot shift later indices.
                    all_completions.append(mx.stop_gradient(mx.array(completion)))
                    all_completion_texts.append(tokenizer.decode(completion))
                    batch_indices.append(prompt_idx)
            mx.metal.clear_cache()
    finally:
        mx.metal.clear_cache()

    return all_completions, all_completion_texts, batch_indices


def grpo_loss(
    model,
    ref_model,
    tokenizer,
    batch,
    completions=None,
    completion_texts=None,
    batch_indices=None,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    beta: float = 0.1,
    group_size: int = 4,
    epsilon: float = 1e-4,
    max_tokens: int = 64,
    temperature: float = 0.8,
    reward_weights: Optional[List[float]] = None,
    batch_size: int = 1,
    is_validation: bool = False,
):
    """Compute the GRPO loss for one batch of prompts.

    Args:
        model: Policy being trained.
        ref_model: Reference model for the KL penalty. When ``None``, the
            policy's own log-probs are used, so the KL term is zero.
        tokenizer: Used for generation and for decoding the validation sample.
        batch: ``(prompt_tokens, answer_tokens, prompt_text, answer_text)``.
        completions, completion_texts, batch_indices: Pre-generated output of
            ``generate_grpo``. If any of them is ``None``, completions are
            generated here using ``max_tokens``, ``group_size``,
            ``temperature`` and ``batch_size``.
        reward_funcs: Callables ``f(prompts=..., completions=..., answer=...)``
            returning one reward per completion. Defaults to the r1 reward set.
        beta: KL penalty coefficient.
        epsilon: Added to the per-group reward std before dividing.
        reward_weights: Optional weight per reward function (default all 1.0).
        is_validation: Print one sample (prompt, generation, answer).

    Returns:
        ``(loss, ntokens, metrics)``. ``ntokens`` is the number of scored
        completion tokens.

    Raises:
        ValueError: If ``reward_weights`` does not match ``reward_funcs`` or
            no completions were generated.
    """
    if reward_funcs is None:
        reward_funcs = list(_DEFAULT_REWARD_FUNCS)
    if reward_weights is not None and len(reward_weights) != len(reward_funcs):
        raise ValueError(
            f"Number of reward weights ({len(reward_weights)}) must match number of reward "
            f"functions ({len(reward_funcs)})"
        )

    prompt_tokens, _, prompt_text, answer_text = batch

    if completions is None or completion_texts is None or batch_indices is None:
        completions, completion_texts, batch_indices = generate_grpo(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_tokens=max_tokens,
            group_size=group_size,
            temperature=temperature,
            batch_size=batch_size,
        )

    if not completions:
        raise ValueError(
            "No completions were generated. Please check your model and inputs."
        )

    # Reorder so each prompt's completions are contiguous, with prompts in
    # ascending index order (stable within a prompt).
    unique_prompt_indices = sorted(set(batch_indices))
    positions_by_prompt = {idx: [] for idx in unique_prompt_indices}
    for position, prompt_idx in enumerate(batch_indices):
        positions_by_prompt[prompt_idx].append(position)
    order = [pos for idx in unique_prompt_indices for pos in positions_by_prompt[idx]]
    group_sizes = [len(positions_by_prompt[idx]) for idx in unique_prompt_indices]

    all_completions = [completions[pos] for pos in order]
    all_completion_texts = [completion_texts[pos] for pos in order]
    batch_indices = [batch_indices[pos] for pos in order]
    expanded_prompts = [prompt_text[idx] for idx in batch_indices]
    expanded_answers = [answer_text[idx] for idx in batch_indices]

    # The model sees prompt + completion, so each completion token is scored
    # conditioned on its prompt; only completion tokens contribute to the loss.
    sequences = []
    prompt_lengths = []
    for prompt_idx, completion_ids in zip(batch_indices, all_completions):
        prompt_ids = list(prompt_tokens[prompt_idx])
        prompt_lengths.append(len(prompt_ids))
        sequences.append(prompt_ids + completion_ids.tolist())
    lengths = [len(seq) for seq in sequences]
    max_length = max(lengths)
    inputs = mx.array([seq + [0] * (max_length - len(seq)) for seq in sequences])

    token_log_probs = get_per_token_logps(model, inputs, lengths, prompt_lengths)
    mx.eval(token_log_probs)
    if ref_model is None:
        ref_token_log_probs = token_log_probs
    else:
        ref_token_log_probs = [
            mx.stop_gradient(x)
            for x in get_per_token_logps(ref_model, inputs, lengths, prompt_lengths)
        ]
        mx.eval(ref_token_log_probs)

    # Right-pad the ragged per-sequence log-probs into (num_completions, max_scored).
    scored_lengths = [x.shape[0] for x in token_log_probs]
    max_scored = max(scored_lengths)

    def _pad_and_stack(rows):
        return mx.stack(
            [
                mx.concatenate(
                    [row, mx.zeros((max_scored - row.shape[0],), dtype=row.dtype)]
                )
                for row in rows
            ]
        )

    token_log_probs = _pad_and_stack(token_log_probs)
    ref_token_log_probs = _pad_and_stack(ref_token_log_probs)
    length_mask = mx.arange(max_scored)[None, :] < mx.array(scored_lengths)[:, None]

    # Each reward function is evaluated once; the results feed both the loss
    # and the metrics.
    func_rewards_list = [
        mx.array(
            reward_func(
                prompts=expanded_prompts,
                completions=all_completion_texts,
                answer=expanded_answers,
            )
        )
        for reward_func in reward_funcs
    ]
    if reward_weights is not None:
        weights = mx.array(reward_weights, dtype=mx.float32)
    else:
        weights = mx.ones(len(reward_funcs), dtype=mx.float32)
    rewards = (mx.stack(func_rewards_list, axis=1) * weights[None, :]).sum(axis=1)

    # Group-relative advantages: (r - group mean) / (group std + epsilon).
    # A group with a single completion has no baseline and gets advantage 0.
    advantage_chunks = []
    group_means = []
    group_stds = []
    offset = 0
    for size in group_sizes:
        group_rewards = rewards[offset : offset + size]
        offset += size
        mean_reward = mx.mean(group_rewards)
        group_means.append(mean_reward)
        if size > 1:
            std_reward = mx.std(group_rewards)
            group_stds.append(std_reward)
            advantage_chunks.append(
                (group_rewards - mean_reward) / (std_reward + epsilon)
            )
        else:
            group_stds.append(mx.array(0.0))
            advantage_chunks.append(mx.zeros_like(group_rewards))
    advantages = mx.concatenate(advantage_chunks)

    # KL divergence via Schulman's approximator: exp(d) - d - 1, with
    # d = ref log-prob - policy log-prob.
    log_ratio = ref_token_log_probs - token_log_probs
    kl_div = mx.exp(log_ratio) - log_ratio - 1

    # Policy ratio pi_theta / pi_theta_old. With one update per generation the
    # old policy is the detached current policy: the value is 1, but the
    # gradient of the log-prob flows through.
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))

    per_token_loss = -(
        (policy_ratio * advantages[:, None] - beta * kl_div) * length_mask
    )

    sequence_lengths = length_mask.sum(axis=1)
    # Guard against sequences with no scored tokens (avoids 0 / 0 -> NaN).
    denominators = mx.maximum(sequence_lengths, 1)
    loss = (per_token_loss.sum(axis=1) / denominators).mean()
    mean_kl = ((kl_div * length_mask).sum(axis=1) / denominators).mean()

    reward_metrics = {}
    for reward_func, func_rewards in zip(reward_funcs, func_rewards_list):
        func_name = reward_func.__name__
        reward_metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
        reward_metrics[f"{func_name}_std"] = mx.std(func_rewards)

    metrics = {
        "total_rewards_mean": mx.mean(rewards),
        "total_rewards_std": mx.std(rewards),
        "grouped_rewards_mean": mx.mean(mx.stack(group_means)),
        "grouped_rewards_std": mx.mean(mx.stack(group_stds)),
        "kl": mean_kl,
        **reward_metrics,
    }

    if is_validation:
        print("\n=== Validation Sample Details ===")
        last_prompt_idx = batch_indices[-1]
        print(f"\n📋 Raw Prompt:\n{prompt_text[last_prompt_idx]}")
        print("\n" + "=" * 10 + "\n")
        actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
        print(f"\n🔄 Model Input:\n{actual_prompt}")
        print("\n" + "=" * 10 + "\n")
        print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
        print("\n" + "=" * 10 + "\n")
        print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
        print("\n" + "=" * 10 + "\n")
        print(
            f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}"
        )
        print("\n" + "=" * 35 + "\n")

    mx.metal.clear_cache()

    return loss, sequence_lengths.sum(), metrics


def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
    """Yield ``(prompts_tokens, answers_tokens, prompts_text, answers_text)`` batches.

    Examples are sorted by prompt + answer token length and grouped into
    consecutive batches of ``batch_size``. In distributed runs, each worker
    takes every ``world_size``-th example of a batch, offset by its rank.
    Prompts longer than ``max_seq_length`` are truncated, with a warning.
    With ``train=True`` the batch order is shuffled and iteration repeats
    forever. Otherwise one pass is made.

    Raises:
        ValueError: If the dataset is malformed or has fewer than
            ``batch_size`` examples, or if ``batch_size`` is not divisible by
            the number of workers.
    """
    if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
        raise ValueError(
            "Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
        )

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size} "
            f"examples but only has {len(dataset)}."
        )

    world = mx.distributed.init()
    step = world.size()
    rank = world.rank()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    idx = sorted(
        range(len(dataset)), key=lambda i: len(dataset[i][0]) + len(dataset[i][1])
    )
    # Offsetting by rank gives each worker a disjoint share of every batch.
    batches = [
        idx[i + rank : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        order = np.random.permutation(len(batches)) if train else range(len(batches))
        for batch_pos in order:
            current_batch = [dataset[j] for j in batches[batch_pos]]
            prompts_tokens = [item[0] for item in current_batch]
            answers_tokens = [item[1] for item in current_batch]
            prompts_text = [item[2] for item in current_batch]
            answers_text = [item[3] for item in current_batch]

            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "Long prompts will be truncated."
                )
                prompts_tokens = [p[:max_seq_length] for p in prompts_tokens]

            yield prompts_tokens, answers_tokens, prompts_text, answers_text

        if not train:
            break


def evaluate_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    epsilon: float,
    group_size: int,
    max_seq_length: int,
    max_tokens: int,
    temperature: float,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    reward_weights: Optional[List[float]] = None,
):
    """Evaluate ``loss_fn`` on up to ``num_batches`` validation batches.

    The loss and metrics are averaged with weights equal to the token count
    returned by ``loss_fn``, then summed across distributed workers.

    Args:
        num_batches: Number of batches to evaluate, or ``-1`` for one full pass.
        reward_funcs: Defaults to the r1 reward set.
        reward_weights: Forwarded to ``loss_fn`` only when not ``None``, so
            custom loss functions without that parameter keep working.
        Other arguments are forwarded to ``iterate_batches`` and ``loss_fn``.

    Returns:
        ``(avg_loss, ntokens, avg_metrics)``. ``avg_loss`` and every value in
        ``avg_metrics`` are Python floats.

    Raises:
        ValueError: If no batch was evaluated.
    """
    if reward_funcs is None:
        reward_funcs = list(_DEFAULT_REWARD_FUNCS)
    extra_loss_kwargs = (
        {} if reward_weights is None else {"reward_weights": reward_weights}
    )

    all_losses = 0
    ntokens = 0
    all_metrics = None

    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_batches(
            dataset=dataset,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        losses, toks, metrics = loss_fn(
            model=model,
            tokenizer=tokenizer,
            batch=batch,
            reward_funcs=reward_funcs,
            beta=beta,
            group_size=group_size,
            epsilon=epsilon,
            ref_model=ref_model,
            temperature=temperature,
            max_tokens=max_tokens,
            is_validation=True,
            **extra_loss_kwargs,
        )
        all_losses += losses * toks
        ntokens += toks

        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        mx.eval(all_losses, ntokens, all_metrics)

    if all_metrics is None:
        raise ValueError(
            "No validation batches were evaluated; check the dataset and num_batches."
        )

    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
    all_metrics = {
        k: mx.distributed.all_sum(v, stream=mx.cpu) for k, v in all_metrics.items()
    }

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, ntokens, avg_metrics


def train_grpo(
    model: nn.Module,
    ref_model: Optional[nn.Module],
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    reward_funcs: Optional[List[RewardFunctions]] = None,
    args: Optional[GRPOTrainingArgs] = None,
    loss_fn: callable = grpo_loss,
    iterate_batches: callable = iterate_grpo_batches,
    training_callback: Optional[TrainingCallback] = None,
):
    """Train ``model`` with GRPO for ``args.iters`` iterations.

    Each iteration samples ``args.group_size`` completions per prompt with
    ``generate_grpo``, then takes one optimizer step on ``loss_fn``. Gradients
    are averaged across distributed workers. The run validates on
    ``val_dataset`` at iteration 1, every ``args.steps_per_eval`` iterations
    and at the end. It reports training metrics every
    ``args.steps_per_report`` iterations and saves adapter weights every
    ``args.steps_per_save`` iterations and at the end.

    Args:
        reward_funcs: Defaults to the r1 reward set.
        args: Defaults to ``GRPOTrainingArgs()``. ``args.reward_weights`` is
            forwarded to the loss when it is set.
        training_callback: Optional receiver of train/val report dicts.
    """
    if reward_funcs is None:
        reward_funcs = list(_DEFAULT_REWARD_FUNCS)
    if args is None:
        args = GRPOTrainingArgs()

    print(
        f"Starting GRPO training with {len(reward_funcs)} reward functions..., iters: {args.iters}"
    )
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()
    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]
    loss_value_and_grad = nn.value_and_grad(model, loss_fn)
    # Passed only when set so custom loss functions without the parameter work.
    extra_loss_kwargs = (
        {} if args.reward_weights is None else {"reward_weights": args.reward_weights}
    )

    def step(batch):
        prompt_tokens, answer_tokens, prompt_text, answer_text = batch

        completions, completion_texts, batch_indices = generate_grpo(
            model=model,
            tokenizer=tokenizer,
            prompt_tokens=prompt_tokens,
            max_tokens=args.max_completion_length,
            group_size=args.group_size,
            temperature=args.temperature,
            batch_size=args.batch_size,
        )

        (loss, toks, metrics), grad = loss_value_and_grad(
            model,
            tokenizer=tokenizer,
            batch=(prompt_tokens, answer_tokens, prompt_text, answer_text),
            completions=completions,
            completion_texts=completion_texts,
            batch_indices=batch_indices,
            reward_funcs=reward_funcs,
            beta=args.beta,
            group_size=args.group_size,
            epsilon=args.epsilon,
            ref_model=ref_model,
            **extra_loss_kwargs,
        )

        grad = average_gradients(grad)
        optimizer.update(model, grad)

        return loss, toks, metrics

    metric_names = [
        "total_rewards_mean",
        "total_rewards_std",
        "grouped_rewards_mean",
        "grouped_rewards_std",
        "kl",
    ]
    for reward_func in reward_funcs:
        metric_names.append(f"{reward_func.__name__}_mean")
        metric_names.append(f"{reward_func.__name__}_std")

    losses = 0
    n_tokens = 0
    steps = 0
    trained_tokens = 0
    accumulated_metrics = {name: 0 for name in metric_names}

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_batches(
            dataset=train_dataset,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                model=model,
                dataset=val_dataset,
                loss_fn=loss_fn,
                ref_model=ref_model,
                reward_funcs=reward_funcs,
                tokenizer=tokenizer,
                group_size=args.group_size,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                max_tokens=args.max_completion_length,
                beta=args.beta,
                epsilon=args.epsilon,
                temperature=args.temperature,
                iterate_batches=iterate_batches,
                reward_weights=args.reward_weights,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                val_metrics_str = (
                    f"Val loss {val_loss:.3f}, "
                    f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
                    f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
                    f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
                    f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
                    f"Val kl {val_metrics['kl']:.3f}"
                )
                for reward_func in reward_funcs:
                    val_metrics_str += (
                        f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
                        f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
                    )
                print(
                    f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report(
                    {
                        "iteration": it,
                        "val_loss": val_loss,
                        **{f"val_{k}": v for k, v in val_metrics.items()},
                        "val_time": val_time,
                    }
                )

            start = time.perf_counter()

        loss, toks, metrics = step(batch)
        losses += loss
        n_tokens += toks
        steps += 1
        mx.metal.clear_cache()

        for k, v in metrics.items():
            accumulated_metrics[k] = accumulated_metrics.get(k, 0) + v

        mx.eval(state, losses, n_tokens, accumulated_metrics)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * world_size
            # Sum each metric across workers before averaging.
            avg_metrics = {
                k: mx.distributed.all_sum(
                    v if isinstance(v, mx.array) else mx.array(v), stream=mx.cpu
                ).item()
                / (steps * world_size)
                for k, v in accumulated_metrics.items()
            }
            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                train_metrics_str = (
                    f"Train loss {train_loss:.3f}, "
                    f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
                    f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
                    f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
                    f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
                    f"KL {avg_metrics['kl']:.3f}"
                )
                for reward_func in reward_funcs:
                    func_name = reward_func.__name__
                    train_metrics_str += (
                        f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
                        f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
                    )
                print(
                    f"Iter {it}: {train_metrics_str}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report(
                    {
                        "iteration": it,
                        "train_loss": train_loss,
                        **{f"train_{k}": v for k, v in avg_metrics.items()},
                        "learning_rate": learning_rate,
                        "iterations_per_second": it_sec,
                        "tokens_per_second": tokens_sec,
                        "trained_tokens": trained_tokens,
                        "peak_memory": peak_mem,
                    }
                )

            # Reset every per-report accumulator, metrics included.
            losses = 0
            n_tokens = 0
            steps = 0
            accumulated_metrics = {name: 0 for name in metric_names}
            start = time.perf_counter()

        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")