diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
new file mode 100644
index 00000000..ae4749c4
--- /dev/null
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -0,0 +1,332 @@
+# Copyright © 2024 Apple Inc.
+
+import glob
+import shutil
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Union
+
+import mlx.core as mx
+import mlx.nn as nn
+import numpy as np
+from mlx.utils import tree_flatten
+
+from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
+
+
+@dataclass
+class GRPOTrainingArgs(TrainingArgs):
+    group_size: int = field(
+        default=4,
+        metadata={"help": "Number of responses per prompt."},
+    )
+    beta: float = field(
+        default=0.1, metadata={"help": "KL penalty coefficient."}
+    )
+
+
+def grpo_loss(model, inputs, targets, lengths, ref_model=None, beta=0.2, group_size=4):
+    """GRPO loss function compatible with MLX training loop."""
+    # Reshape inputs to account for multiple generations per prompt
+    batch_size = inputs.shape[0] // group_size
+
+    # Get logits from current model
+    logits = model(inputs).astype(mx.float32)
+
+    # Calculate log probabilities
+    log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
+
+    # Gather actual token probabilities
+    targets = targets[:, :log_probs.shape[1]]
+    token_log_probs = mx.take_along_axis(
+        log_probs,
+        targets.reshape(*targets.shape, 1),
+        axis=-1
+    ).squeeze(-1)
+
+    # Get reference model log probabilities
+    if ref_model is None:
+        with model.disable_adapter():  # Assuming adapter-based fine-tuning
+            ref_logits = model(inputs).astype(mx.float32)
+    else:
+        ref_logits = ref_model(inputs).astype(mx.float32)
+
+    ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
+    ref_token_log_probs = mx.take_along_axis(
+        ref_log_probs,
+        targets.reshape(*targets.shape, 1),
+        axis=-1
+    ).squeeze(-1)
+
+    # Calculate KL divergence
+    kl_div = (mx.exp(ref_token_log_probs - token_log_probs) -
+              (ref_token_log_probs - token_log_probs) - 1)
+
+    # Calculate rewards (placeholder - implement actual reward calculation)
+    rewards = mx.random.normal((inputs.shape[0],))
+
+    # Calculate group advantages
+    grouped_rewards = rewards.reshape(batch_size, group_size)
+    means = mx.mean(grouped_rewards, axis=1)
+    stds = mx.std(grouped_rewards, axis=1)
+    means = mx.repeat(means.reshape(-1, 1), group_size, axis=1).reshape(-1)
+    stds = mx.repeat(stds.reshape(-1, 1), group_size, axis=1).reshape(-1)
+    advantages = (rewards - means) / (stds + 1e-8)
+
+    # Calculate policy gradient loss
+    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
+    pg_loss = -policy_ratio * advantages.reshape(-1, 1)
+
+    # Create length mask
+    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
+
+    # Combine losses
+    loss = (pg_loss + beta * kl_div) * length_mask
+    ntoks = length_mask.sum()
+    loss = loss.sum() / ntoks
+
+    return loss, ntoks
+
+
+def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
+    # Sort by length:
+    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
+    if len(dataset) < batch_size:
+        raise ValueError(
+            f"Dataset must have at least batch_size={batch_size}"
+            f" examples but only has {len(dataset)}."
+        )
+
+    # If running in distributed mode (N machines) then each one should skip N-1
+    # samples
+    step = mx.distributed.init().size()
+    if batch_size % step != 0:
+        raise ValueError("The batch size must be divisible by the number of workers")
+
+    # Make the batches:
+    batch_idx = [
+        idx[i : i + batch_size : step]
+        for i in range(0, len(idx) - batch_size + 1, batch_size)
+    ]
+
+    while True:
+        indices = np.random.permutation(len(batch_idx))
+        for i in indices:
+            batch = [dataset[j] for j in batch_idx[i]]
+            lengths = [len(x) for x in batch]
+            if max(lengths) > max_seq_length:
+                print(
+                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
+                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
+                    "Consider pre-splitting your data to save memory."
+                )
+
+            # Pad to the nearest multiple of 8 or the maximum length
+            pad_to = 8
+            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
+            max_length_in_batch = min(max_length_in_batch, max_seq_length)
+
+            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
+
+            for j in range(batch_size // step):
+                truncated_length = min(lengths[j], max_seq_length)
+                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
+                lengths[j] = (
+                    truncated_length  # Update lengths to match truncated lengths
+                )
+            batch = mx.array(batch_arr)
+
+            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
+
+        if not train:
+            break
+
+
+def evaluate(
+    model,
+    dataset,
+    tokenizer,
+    batch_size,
+    num_batches,
+    max_seq_length=2048,
+    loss: callable = grpo_loss,
+    iterate_batches: callable = iterate_batches,
+):
+    all_losses = 0
+    ntokens = 0
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+
+    for _, batch in zip(
+        index_iterator,
+        iterate_batches(
+            dataset=dataset,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_seq_length=max_seq_length,
+        ),
+    ):
+        losses, toks = loss(model, *batch)
+        all_losses += losses * toks
+        ntokens += toks
+        mx.eval(all_losses, ntokens)
+
+    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
+    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
+
+    return (all_losses / ntokens).item()
+
+
+def train(
+    model,
+    tokenizer,
+    optimizer,
+    train_dataset,
+    val_dataset,
+    args: GRPOTrainingArgs = GRPOTrainingArgs(),
+    loss: callable = grpo_loss,
+    iterate_batches: callable = iterate_batches,
+    training_callback: TrainingCallback = None,
+):
+    print(f"Starting training..., iters: {args.iters}")
+    world = mx.distributed.init()
+    world_size = world.size()
+    rank = world.rank()
+    if world_size > 1:
+        print(f"Node {rank} of {world_size}")
+
+    if args.grad_checkpoint:
+        grad_checkpoint(model.layers[0])
+
+    state = [model.state, optimizer.state]
+
+    def step(batch):
+        # Forward and backward pass
+        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
+
+        # All reduce the gradients if running in distributed mode
+        grad = average_gradients(grad)
+
+        # Model update
+        optimizer.update(model, grad)
+
+        return lvalue, toks
+
+    loss_value_and_grad = nn.value_and_grad(model, loss)
+
+    # Save initial model weights as reference
+    ref_weights = {k: v.copy() for k, v in model.parameters().items()}
+
+    losses = 0
+    n_tokens = 0
+    steps = 0
+    trained_tokens = 0
+    # Main training loop
+    start = time.perf_counter()
+    for it, batch in zip(
+        range(1, args.iters + 1),
+        iterate_batches(
+            dataset=train_dataset,
+            tokenizer=tokenizer,
+            batch_size=args.batch_size,
+            max_seq_length=args.max_seq_length,
+            train=True,
+        ),
+    ):
+        # Report validation loss if needed, the first validation loss
+        # is always measured before any training.
+        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
+            stop = time.perf_counter()
+            val_loss = evaluate(
+                model=model,
+                dataset=val_dataset,
+                loss=loss,
+                tokenizer=tokenizer,
+                batch_size=args.batch_size,
+                num_batches=args.val_batches,
+                max_seq_length=args.max_seq_length,
+                iterate_batches=iterate_batches,
+            )
+            val_time = time.perf_counter() - stop
+            if rank == 0:
+                print(
+                    f"Iter {it}: "
+                    f"Val loss {val_loss:.3f}, "
+                    f"Val took {val_time:.3f}s",
+                    flush=True,
+                )
+
+            if training_callback is not None:
+                val_info = {
+                    "iteration": it,
+                    "val_loss": val_loss,
+                    "val_time": val_time,
+                }
+                training_callback.on_val_loss_report(val_info)
+
+            start = time.perf_counter()
+
+        lvalue, toks = step(batch)
+        losses += lvalue
+        n_tokens += toks
+        steps += 1
+        mx.eval(state, losses, n_tokens)
+
+        # Report training loss if needed
+        if it % args.steps_per_report == 0 or it == args.iters:
+            stop = time.perf_counter()
+
+            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
+            train_loss /= steps * mx.distributed.init().size()
+            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
+            learning_rate = optimizer.learning_rate.item()
+            it_sec = args.steps_per_report / (stop - start)
+            tokens_sec = float(n_tokens) / (stop - start)
+            trained_tokens += n_tokens
+            peak_mem = mx.metal.get_peak_memory() / 1e9
+            if rank == 0:
+                print(
+                    f"Iter {it}: Train loss {train_loss:.3f}, "
+                    f"Learning Rate {learning_rate:.3e}, "
+                    f"It/sec {it_sec:.3f}, "
+                    f"Tokens/sec {tokens_sec:.3f}, "
+                    f"Trained Tokens {trained_tokens}, "
+                    f"Peak mem {peak_mem:.3f} GB",
+                    flush=True,
+                )
+
+            if training_callback is not None:
+                train_info = {
+                    "iteration": it,
+                    "train_loss": train_loss,
+                    "learning_rate": learning_rate,
+                    "iterations_per_second": it_sec,
+                    "tokens_per_second": tokens_sec,
+                    "trained_tokens": trained_tokens,
+                    "peak_memory": peak_mem,
+                }
+                training_callback.on_train_loss_report(train_info)
+
+            losses = 0
+            n_tokens = 0
+            steps = 0
+            start = time.perf_counter()
+
+        # Save adapter weights
+        if it % args.steps_per_save == 0:
+            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+            mx.save_safetensors(str(args.adapter_file), adapter_weights)
+            checkpoint = (
+                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
+            )
+            mx.save_safetensors(str(checkpoint), adapter_weights)
+            print(
+                f"Iter {it}: Saved adapter weights to "
+                f"{args.adapter_file} and {checkpoint}."
+            )
+
+    # Save final weights
+    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
+    mx.save_safetensors(str(args.adapter_file), adapter_weights)
+    print(f"Saved final weights to {args.adapter_file}.")
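
In `grpo_loss`, the rewards are currently a random placeholder (`rewards = mx.random.normal((inputs.shape[0],))`). Below is a minimal sketch of how a task-specific reward could slot in instead; `exact_match_reward`, `completions`, and `answers` are illustrative names, not part of this diff, and they assume the `group_size` samples per prompt have already been decoded to text somewhere upstream.

```python
import mlx.core as mx

# Illustrative sketch only, not part of the diff: a toy exact-match reward.
# `completions` and `answers` are assumed to be lists of decoded strings,
# one entry per generated sample (length == batch_size * group_size).
def exact_match_reward(completions, answers):
    return mx.array(
        [1.0 if c.strip() == a.strip() else 0.0 for c, a in zip(completions, answers)]
    )

# Inside grpo_loss, the placeholder
#     rewards = mx.random.normal((inputs.shape[0],))
# would then become something like
#     rewards = exact_match_reward(completions, answers)
# The group-advantage normalization below it stays unchanged: rewards are
# reshaped to (batch_size, group_size), and each sample's reward is centered
# by its group mean and scaled by its group standard deviation.
```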
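For context, here is a rough end-to-end sketch of driving the trainer, mirroring how the existing LoRA trainer is wired up. The model path, prompt lists, and hyperparameters are placeholders, and it assumes LoRA adapters have already been attached to the model, since `grpo_loss` calls `model.disable_adapter()` when no `ref_model` is given.

```python
import mlx.optimizers as optim
from mlx_lm import load

from mlx_lm.tuner.grpo_trainer import GRPOTrainingArgs, train

# Placeholder path: any mlx-lm compatible model with LoRA adapters attached.
model, tokenizer = load("path/to/model")

# iterate_batches only needs an indexable collection of token-id sequences,
# so tokenized prompts are enough here (toy data for illustration).
train_prompts = ["Q: 2 + 2 = ?\nA:", "Q: 3 + 5 = ?\nA:"] * 4
val_prompts = ["Q: 1 + 6 = ?\nA:"] * 4
train_set = [tokenizer.encode(p) for p in train_prompts]
val_set = [tokenizer.encode(p) for p in val_prompts]

args = GRPOTrainingArgs(
    batch_size=4,  # grpo_loss groups rows as batch_size // group_size, so keep these aligned
    iters=200,
    group_size=4,
    beta=0.1,
)

train(
    model=model,
    tokenizer=tokenizer,
    optimizer=optim.Adam(learning_rate=1e-5),
    train_dataset=train_set,
    val_dataset=val_set,
    args=args,
)
```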