diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 43f508c3..28aa3420 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -15,6 +15,7 @@ import yaml
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
+from .tuner.grpo_trainer import GRPOTrainingArgs, evaluate_grpo, train_grpo
 from .tuner.utils import (
     build_schedule,
     linear_to_lora_layers,
diff --git a/llms/mlx_lm/tuner/grpo_trainer.py b/llms/mlx_lm/tuner/grpo_trainer.py
index b5030ce0..720533b1 100644
--- a/llms/mlx_lm/tuner/grpo_trainer.py
+++ b/llms/mlx_lm/tuner/grpo_trainer.py
@@ -290,7 +290,7 @@ def evaluate_grpo(
     return avg_loss, ntokens, avg_metrics
 
 
-def train(
+def train_grpo(
     model,
     tokenizer,
     optimizer,
@@ -354,7 +354,7 @@ def train(
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
-            val_loss, val_ntokens, val_metrics = evaluate(
+            val_loss, val_ntokens, val_metrics = evaluate_grpo(
                 model=model,
                 dataset=val_dataset,
                 loss=loss,
@@ -458,358 +458,4 @@ def train(
     # Save final weights
     adapter_weights = dict(tree_flatten(model.trainable_parameters()))
     mx.save_safetensors(str(args.adapter_file), adapter_weights)
-    print(f"Saved final weights to {args.adapter_file}.")
-
-
-
-
-
-
-
-
-
-
-
-
-
-# Copyright © 2024 Apple Inc.
-
-import glob
-import shutil
-import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Union
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-from mlx.utils import tree_flatten
-
-from .trainer import grad_checkpoint, TrainingArgs, TrainingCallback, average_gradients
-
-
-@dataclass
-class GRPOTrainingArgs(TrainingArgs):
-    group_size: int = field(
-        default=4,
-        metadata={"help": "Number of response sper prompt."},
-    )
-    beta: float = field(
-        default=0.1, metadata={"help": "KL penalty coefficient."}
-    )
-
-
-def grpo_loss(
-    model,
-    reference_teacher_model,
-    inputs,
-    targets,
-    lengths,
-    beta=0.2,
-    group_size=4,
-    is_reference_free: bool = False
-    ):
-    """GRPO loss function compatible with MLX training loop."""
-    # Reshape inputs to account for multiple generations per prompt
-    batch_size = inputs.shape[0] // group_size
-
-    # Get logits from current model
-    logits = model(inputs).astype(mx.float32)
-
-    # Calculate log probabilities
-    log_probs = mx.log_softmax(logits[:, :-1, :], axis=-1)
-
-    # Gather actual token probabilities
-    targets = targets[:, :log_probs.shape[1]]
-    token_log_probs = mx.take_along_axis(
-        log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
-
-    # Get reference model log probabilities
-    if ref_model is None:
-        with model.disable_adapter():  # Assuming adapter-based fine-tuning
-            ref_logits = model(inputs).astype(mx.float32)
-    else:
-        ref_logits = ref_model(inputs).astype(mx.float32)
-
-    ref_log_probs = mx.log_softmax(ref_logits[:, :-1, :], axis=-1)
-    ref_token_log_probs = mx.take_along_axis(
-        ref_log_probs,
-        targets.reshape(*targets.shape, 1),
-        axis=-1
-    ).squeeze(-1)
-
-    # Calculate KL divergence
-    kl_div = (mx.exp(ref_token_log_probs - token_log_probs) -
-              (ref_token_log_probs - token_log_probs) - 1)
-
-    # Calculate rewards (placeholder - implement actual reward calculation)
-    rewards = mx.random.normal((inputs.shape[0],))
-
-    # Calculate group advantages
-    grouped_rewards = rewards.reshape(batch_size, group_size)
-    means = mx.mean(grouped_rewards, axis=1)
-    stds = mx.std(grouped_rewards, axis=1)
-    means = mx.repeat(means.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    stds = mx.repeat(stds.reshape(-1, 1), group_size, axis=1).reshape(-1)
-    advantages = (rewards - means) / (stds + 1e-8)
-
-    # Calculate policy gradient loss
-    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(token_log_probs))
-    pg_loss = -policy_ratio * advantages.reshape(-1, 1)
-
-    # Create length mask
-    length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
-
-    # Combine losses
-    loss = (pg_loss + beta * kl_div) * length_mask
-    ntoks = length_mask.sum()
-    loss = loss.sum() / ntoks
-
-    return loss, ntoks
-
-
-def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    # Sort by length:
-    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
-
-    # If running in distributed mode (N machines) then each one should skip N-1
-    # samples
-    step = mx.distributed.init().size()
-    if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-
-    # Make the batches:
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-
-    while True:
-        indices = np.random.permutation(len(batch_idx))
-        for i in indices:
-            batch = [dataset[j] for j in batch_idx[i]]
-            lengths = [len(x) for x in batch]
-            if max(lengths) > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. "
-                    "Consider pre-splitting your data to save memory."
-                )
-
-            # Pad to the nearest multiple of 8 or the maximum length
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-
-            for j in range(batch_size // step):
-                truncated_length = min(lengths[j], max_seq_length)
-                batch_arr[j, :truncated_length] = batch[j][:truncated_length]
-                lengths[j] = (
-                    truncated_length  # Update lengths to match truncated lengths
-                )
-            batch = mx.array(batch_arr)
-
-            yield batch[:, :-1], batch[:, 1:], mx.array(lengths)
-
-        if not train:
-            break
-
-
-def evaluate(
-    model,
-    dataset,
-    tokenizer,
-    batch_size,
-    num_batches,
-    max_seq_length=2048,
-    loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches,
-):
-    all_losses = 0
-    ntokens = 0
-
-    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-
-    for _, batch in zip(
-        index_iterator,
-        iterate_batches(
-            dataset=dataset,
-            tokenizer=tokenizer,
-            batch_size=batch_size,
-            max_seq_length=max_seq_length,
-        ),
-    ):
-        losses, toks = loss(model, *batch)
-        all_losses += losses * toks
-        ntokens += toks
-        mx.eval(all_losses, ntokens)
-
-    all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
-    ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
-
-    return (all_losses / ntokens).item()
-
-
-def train(
-    model,
-    tokenizer,
-    optimizer,
-    train_dataset,
-    val_dataset,
-    args: GRPOTrainingArgs = GRPOTrainingArgs(),
-    loss: callable = grpo_loss,
-    iterate_batches: callable = iterate_batches,
-    training_callback: TrainingCallback = None,
-):
-    print(f"Starting training..., iters: {args.iters}")
-    world = mx.distributed.init()
-    world_size = world.size()
-    rank = world.rank()
-    if world_size > 1:
-        print(f"Node {rank} of {world_size}")
-
-    if args.grad_checkpoint:
-        grad_checkpoint(model.layers[0])
-
-    state = [model.state, optimizer.state]
-
-    def step(batch):
-        # Forward and backward pass
-        (lvalue, toks), grad = loss_value_and_grad(model, *batch)
-
-        # All reduce the gradients if running in distributed mode
-        grad = average_gradients(grad)
-
-        # Model update
-        optimizer.update(model, grad)
-
-        return lvalue, toks
-
-    loss_value_and_grad = nn.value_and_grad(model, loss)
-
-    # Save initial model weights as reference
-    ref_weights = {k: v.copy() for k, v in model.parameters().items()}
-
-    losses = 0
-    n_tokens = 0
-    steps = 0
-    trained_tokens = 0
-    # Main training loop
-    start = time.perf_counter()
-    for it, batch in zip(
-        range(1, args.iters + 1),
-        iterate_batches(
-            dataset=train_dataset,
-            tokenizer=tokenizer,
-            batch_size=args.batch_size,
-            max_seq_length=args.max_seq_length,
-            train=True,
-        ),
-    ):
-        # Report validation loss if needed, the first validation loss
-        # is always measured before any training.
-        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
-            val_loss = evaluate(
-                model=model,
-                dataset=val_dataset,
-                loss=loss,
-                tokenizer=tokenizer,
-                batch_size=args.batch_size,
-                num_batches=args.val_batches,
-                max_seq_length=args.max_seq_length,
-                iterate_batches=iterate_batches,
-            )
-            val_time = time.perf_counter() - stop
-            if rank == 0:
-                print(
-                    f"Iter {it}: "
-                    f"Val loss {val_loss:.3f}, "
-                    f"Val took {val_time:.3f}s",
-                    flush=True,
-                )
-
-            if training_callback is not None:
-                val_info = {
-                    "iteration": it,
-                    "val_loss": val_loss,
-                    "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
-
-            start = time.perf_counter()
-
-        lvalue, toks = step(batch)
-        losses += lvalue
-        n_tokens += toks
-        steps += 1
-        mx.eval(state, losses, n_tokens)
-
-        # Report training loss if needed
-        if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
-            train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
-            train_loss /= steps * mx.distributed.init().size()
-            n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
-            learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
-            trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 1e9
-            if rank == 0:
-                print(
-                    f"Iter {it}: Train loss {train_loss:.3f}, "
-                    f"Learning Rate {learning_rate:.3e}, "
-                    f"It/sec {it_sec:.3f}, "
-                    f"Tokens/sec {tokens_sec:.3f}, "
-                    f"Trained Tokens {trained_tokens}, "
-                    f"Peak mem {peak_mem:.3f} GB",
-                    flush=True,
-                )
-
-            if training_callback is not None:
-                train_info = {
-                    "iteration": it,
-                    "train_loss": train_loss,
-                    "learning_rate": learning_rate,
-                    "iterations_per_second": it_sec,
-                    "tokens_per_second": tokens_sec,
-                    "trained_tokens": trained_tokens,
-                    "peak_memory": peak_mem,
-                }
-                training_callback.on_train_loss_report(train_info)
-
-            losses = 0
-            n_tokens = 0
-            steps = 0
-            start = time.perf_counter()
-
-        # Save adapter weights
-        if it % args.steps_per_save == 0:
-            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
-            mx.save_safetensors(str(args.adapter_file), adapter_weights)
-            checkpoint = (
-                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
-            )
-            mx.save_safetensors(str(checkpoint), adapter_weights)
-            print(
-                f"Iter {it}: Saved adapter weights to "
-                f"{args.adapter_file} and {checkpoint}."
-            )
-
-    # Save final weights
-    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
-    mx.save_safetensors(str(args.adapter_file), adapter_weights)
-    print(f"Saved final weights to {args.adapter_file}.")
+    print(f"Saved final weights to {args.adapter_file}.")
\ No newline at end of file
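For context, a minimal sketch of how the renamed entry points would be called once this patch lands. This is not part of the diff: the model path, optimizer choice, dataset placeholders, and hyperparameter values are illustrative assumptions; only the `train_grpo` signature and the `GRPOTrainingArgs` fields (`group_size`, `beta`, plus the inherited `TrainingArgs` fields the trainer reads, such as `iters`, `batch_size`, and `adapter_file`) come from the patch itself.

import mlx.optimizers as optim
from mlx_lm.utils import load
from mlx_lm.tuner.grpo_trainer import GRPOTrainingArgs, train_grpo

# Sketch only: assumes the mlx_lm package is installed and a LoRA-adapted
# model/tokenizer can be loaded; replace the path with a real checkpoint.
model, tokenizer = load("path/to/lora-adapted-model")  # placeholder path

train_set = ...  # placeholder: tokenized training prompts (see tuner.datasets)
val_set = ...    # placeholder: tokenized validation prompts

args = GRPOTrainingArgs(
    iters=200,                           # read as args.iters in the trainer
    batch_size=4,                        # grpo_loss groups rows into batch_size // group_size prompts
    group_size=4,                        # responses per prompt (GRPO-specific field)
    beta=0.1,                            # KL penalty coefficient (GRPO-specific field)
    adapter_file="adapters.safetensors", # where adapter weights are saved
)

train_grpo(
    model=model,
    tokenizer=tokenizer,
    optimizer=optim.Adam(learning_rate=1e-5),  # assumed optimizer choice
    train_dataset=train_set,
    val_dataset=val_set,
    args=args,
)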