diff --git a/llms/mlx_lm/lora.py b/llms/mlx_lm/lora.py
index 35c20274..9f5427a9 100644
--- a/llms/mlx_lm/lora.py
+++ b/llms/mlx_lm/lora.py
@@ -15,7 +15,6 @@ import yaml
 from .tokenizer_utils import TokenizerWrapper
 from .tuner.datasets import load_dataset
 from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
-from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
 from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
 from .tuner.utils import (
     build_schedule,
@@ -176,7 +175,7 @@ def build_parser():
         default=None,
     )
     parser.add_argument("--beta", type=float)
-    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpop"])
+    parser.add_argument("--dpo-loss-type", type=str, choices=["sigmoid", "hinge", "ipo", "dpo"])
     parser.add_argument("--is-reference-free", action="store_true")
     parser.add_argument("--delta", type=float)
     parser.add_argument("--reference-model-path", type=str)
@@ -229,40 +228,7 @@ def train_model(
     )
 
     # Train model based on training mode
-    if args.training_mode == "dpo":
-        training_args = DPOTrainingArgs(
-            batch_size=args.batch_size,
-            iters=args.iters,
-            val_batches=args.val_batches,
-            steps_per_report=args.steps_per_report,
-            steps_per_eval=args.steps_per_eval,
-            steps_per_save=args.save_every,
-            adapter_file=adapter_file,
-            max_seq_length=args.max_seq_length,
-            grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta,
-            loss_type=args.dpo_loss_type,
-            is_reference_free=args.is_reference_free,
-            delta=args.delta,
-            reference_model_path=args.reference_model_path,
-        )
-
-        if args.reference_model_path:
-            reference_model, _ = load(args.reference_model_path)
-        else:
-            reference_model, _ = load(args.model)
-
-        train_dpo(
-            model=model,
-            reference_model=reference_model.freeze(),
-            tokenizer=tokenizer,
-            optimizer=opt,
-            train_dataset=train_set,
-            val_dataset=valid_set,
-            args=training_args,
-            training_callback=training_callback,
-        )
-    elif args.training_mode == "orpo":
+    if args.training_mode == "orpo":
         training_args = ORPOTrainingArgs(
             batch_size=args.batch_size,
             iters=args.iters,
@@ -273,8 +239,7 @@ def train_model(
             adapter_file=adapter_file,
             max_seq_length=args.max_seq_length,
             grad_checkpoint=args.grad_checkpoint,
-            beta=args.beta,
-            reward_scaling=args.reward_scaling,
+            beta=args.beta
         )
 
         train_orpo(
@@ -284,7 +249,7 @@ def train_model(
             train_dataset=train_set,
             val_dataset=valid_set,
             args=training_args,
-            training_callback=training_callback,
+            training_callback=training_callback
         )
     else:
         training_args = TrainingArgs(
@@ -313,26 +278,7 @@ def train_model(
 def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set):
     model.eval()
 
-    if args.training_mode == "dpo":
-        if args.reference_model_path:
-            reference_model, _ = load(args.reference_model_path)
-        else:
-            reference_model = model
-
-        test_loss, test_rewards = evaluate_dpo(
-            model=model,
-            reference_model=reference_model,
-            dataset=test_set,
-            tokenizer=tokenizer,
-            batch_size=args.batch_size,
-            num_batches=args.test_batches,
-            max_seq_length=args.max_seq_length,
-            beta=args.beta,
-            delta=args.delta,
-            loss_type=args.dpo_loss_type,
-        )
-        print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
-    elif args.training_mode == "orpo":
+    if args.training_mode == "orpo":
         test_loss, test_rewards = evaluate_orpo(
             model=model,
             dataset=test_set,
@@ -340,8 +286,7 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
             batch_size=args.batch_size,
             num_batches=args.test_batches,
             max_seq_length=args.max_seq_length,
-            beta=args.beta,
-            reward_scaling=args.reward_scaling,
+            beta=args.beta
         )
         print(f"Test loss {test_loss:.8f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
     else:
diff --git a/llms/mlx_lm/tuner/datasets.py b/llms/mlx_lm/tuner/datasets.py
index f0fe45a2..0914c6b7 100644
--- a/llms/mlx_lm/tuner/datasets.py
+++ b/llms/mlx_lm/tuner/datasets.py
@@ -4,70 +4,47 @@ from typing import Dict, List, Optional
 
 from transformers import PreTrainedTokenizer
 
-
-class DPODataset:
-    """
-    A dataset for DPO (Direct Preference Optimization) training that handles
-    prompt-chosen-rejected triplets with optional scores in the format:
-    {"prompt": ..., "chosen": ..., "rejected": ..., "score_chosen": ..., "score_rejected": ...}
-    """
-
-    def __init__(
-        self,
-        data: List[Dict[str, str]],
-        tokenizer: PreTrainedTokenizer,
-        prompt_key: str = "prompt",
-        chosen_key: str = "chosen",
-        rejected_key: str = "rejected",
-        score_chosen_key: str = "score_chosen",
-        score_rejected_key: str = "score_rejected",
-    ):
+class ORPODataset:
+    def __init__(
+        self,
+        data: List[Dict[str, str]],
+        tokenizer: PreTrainedTokenizer,
+        prompt_key: str = "prompt",
+        chosen_key: str = "chosen",
+        rejected_key: str = "rejected",
+        preference_score_key: str = "preference_score"
+    ):
         self._chosen_data = []
         self._rejected_data = []
         self._scores = []
-
+
         for d in data:
-            # Process the text data
-            chosen_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[chosen_key]},
-                ],
-            )
-            rejected_text = tokenizer.apply_chat_template(
-                [
-                    {"role": "user", "content": d[prompt_key]},
-                    {"role": "assistant", "content": d[rejected_key]},
-                ],
-            )
-
+            chosen_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[chosen_key]},
+            ])
+            rejected_text = tokenizer.apply_chat_template([
+                {"role": "user", "content": d[prompt_key]},
+                {"role": "assistant", "content": d[rejected_key]},
+            ])
+
             self._chosen_data.append(chosen_text)
             self._rejected_data.append(rejected_text)
-
-            # Handle scores if they exist
-            if score_chosen_key in d and score_rejected_key in d:
-                chosen_score = float(d[score_chosen_key])
-                rejected_score = float(d[score_rejected_key])
-
-                # Normalize scores to [0, 1] range
-                score_diff = chosen_score - rejected_score
-                max_diff = max(abs(score_diff), 1.0)  # Avoid division by zero
-                normalized_score = (score_diff / max_diff + 1) / 2
-
-                self._scores.append(normalized_score)
+
+            if preference_score_key in d:
+                self._scores.append(float(d[preference_score_key]))
             else:
-                # Default to binary preference (1.0) if no scores provided
                 self._scores.append(1.0)
-
-    def __getitem__(self, idx: int):
-        return {
-            "chosen": self._chosen_data[idx],
-            "rejected": self._rejected_data[idx],
-            "preference_score": self._scores[idx]
-        }
-
-    def __len__(self):
-        return len(self._chosen_data)
+
+    def __getitem__(self, idx: int):
+        return {
+            "chosen": self._chosen_data[idx],
+            "rejected": self._rejected_data[idx],
+            "preference_score": self._scores[idx]
+        }
+
+    def __len__(self):
+        return len(self._chosen_data)
 
 
 class Dataset:
@@ -158,7 +135,7 @@ def create_dataset(
 
     # Add DPO dataset support
     if "chosen" in sample and "rejected" in sample:
-        return DPODataset(data, tokenizer)
+        return ORPODataset(data, tokenizer)
     elif "messages" in sample:
         return ChatDataset(data, tokenizer)
     elif prompt_feature in sample and completion_feature in sample:
diff --git a/llms/mlx_lm/tuner/dpo_trainer.py b/llms/mlx_lm/tuner/dpo_trainer.py
deleted file mode 100644
index 657daf28..00000000
--- a/llms/mlx_lm/tuner/dpo_trainer.py
+++ /dev/null
@@ -1,457 +0,0 @@
-# Copyright © 2024 Apple Inc.
-
-import glob
-import shutil
-import time
-from dataclasses import dataclass, field
-from pathlib import Path
-from typing import Union
-
-import mlx.core as mx
-import mlx.nn as nn
-import numpy as np
-from mlx.nn.utils import average_gradients
-from mlx.utils import tree_flatten
-from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs
-
-
-@dataclass
-class DPOTrainingArgs(TrainingArgs):
-    beta: float = field(
-        default=0.1,
-        metadata={"help": "Temperature parameter for DPO training."}
-    )
-    loss_type: str = field(
-        default="sigmoid",
-        metadata={
-            "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
-        }
-    )
-    is_reference_free: bool = field(
-        default=False,
-        metadata={
-            "help": "Whether to use reference-free DPO training."
-        }
-    )
-    delta: float = field(
-        default=50.0,
-        metadata={
-            "help": "Delta parameter for DPOP loss type."
-        }
-    )
-    reference_model_path: str = field(
-        default=None,
-        metadata={
-            "help": "Path to reference model weights. If None, uses the same model."
-        }
-    )
-    seed: int = field(
-        default=42,
-        metadata={
-            "help": "Random seed for reproducibility."
-        }
-    )
-
-
-def dpo_loss(
-    model,
-    reference_teacher_model,
-    chosen: mx.array,
-    rejected: mx.array,
-    chosen_masks: mx.array,
-    rejected_masks: mx.array,
-    beta: float,
-    delta: float,
-    loss_type: str = "sigmoid",
-    is_reference_free: bool = False
-):
-    """
-    Calculate loss for inputs.
-    Args:
-        inputs: Input tokens.
-        targets: Target tokens.
-        lengths: Lengths of inputs.
-    Returns:
-        Loss value.
-    """
-    def make_predictions(model, x, mask):
-        inputs = x[:, :-1]
-        targets = x[:, 1:]
-
-        logits = model(inputs)
-        logits = logits.astype(mx.float32)
-
-        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
-
-    num_chosen_tokens = chosen_masks.sum(-1)
-    num_rejected_tokens = rejected_masks.sum(-1)
-
-    # Calculate log probabilities for policy model
-    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
-    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
-    if loss_type == "ipo":
-        # ipo uses average log probabilities
-        policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens
-        policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens
-    else:
-        policy_chosen_score = policy_chosen_scores.sum(-1)
-        policy_rejected_score = policy_rejected_scores.sum(-1)
-
-    # Calculate log probabilities for reference model
-    if is_reference_free:
-        reference_chosen_score = mx.zeros_like(policy_chosen_score)
-        reference_rejected_score = mx.zeros_like(policy_rejected_score)
-    else:
-        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
-        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
-        if loss_type == "ipo":
-            # ipo uses average log probabilities
-            reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
-            reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens
-        else:
-            reference_chosen_score = reference_chosen_scores.sum(-1)
-            reference_rejected_score = reference_rejected_scores.sum(-1)
-
-    logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)
-
-    if loss_type == "sigmoid":
-        losses = -nn.log_sigmoid(beta * logits)
-    elif loss_type == "hinge":
-        losses = nn.relu(1 - beta * logits)
-    elif loss_type == "ipo":
-        losses = (logits - 1 / (2 * beta)) ** 2
-    elif loss_type == "dpop":
-        delta = 50
-        penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
-        losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
-    else:
-        raise ValueError(f"Unknown loss type: {loss_type}")
-
-    loss = mx.mean(losses)
-    num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()
-
-    chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
-    rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
-    reward = mx.stack([chosen_reward, rejected_reward])
-
-    return loss, reward, num_tokens
-
-
-def iterate_dpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
-    """
-    Modified iterate_batches for DPO training that handles chosen and rejected samples.
-    """
-    # Sort pairs by length of the chosen response
-    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
-    if len(dataset) < batch_size:
-        raise ValueError(
-            f"Dataset must have at least batch_size={batch_size}"
-            f" examples but only has {len(dataset)}."
-        )
-
-    step = mx.distributed.init().size()
-    if batch_size % step != 0:
-        raise ValueError("The batch size must be divisible by the number of workers")
-
-    batch_idx = [
-        idx[i : i + batch_size : step]
-        for i in range(0, len(idx) - batch_size + 1, batch_size)
-    ]
-
-    while True:
-        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
-        for i in indices:
-            batch = [dataset[j] for j in batch_idx[i]]
-
-            # Get lengths for chosen and rejected sequences
-            chosen_lengths = [len(x['chosen']) for x in batch]
-            rejected_lengths = [len(x['rejected']) for x in batch]
-            max_length = max(max(chosen_lengths), max(rejected_lengths))
-
-            if max_length > max_seq_length:
-                print(
-                    f"[WARNING] Some sequences are longer than {max_seq_length} tokens. "
-                    f"The longest sequence {max_length} will be truncated to {max_seq_length}."
-                )
-
-            # Pad to nearest multiple of 8
-            pad_to = 8
-            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
-            max_length_in_batch = min(max_length_in_batch, max_seq_length)
-
-            # Create arrays for chosen and rejected sequences
-            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
-
-            # Create attention masks
-            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
-            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
-
-            for j in range(batch_size // step):
-                # Process chosen sequence
-                chosen_length = min(chosen_lengths[j], max_seq_length)
-                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
-                chosen_masks[j, :chosen_length] = 1.0
-
-                # Process rejected sequence
-                rejected_length = min(rejected_lengths[j], max_seq_length)
-                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
-                rejected_masks[j, :rejected_length] = 1.0
-
-            yield (mx.array(chosen_arr), mx.array(rejected_arr),
-                   mx.array(chosen_masks), mx.array(rejected_masks))
-
-        if not train:
-            break
-
-
-def evaluate_dpo(
-    model,
-    reference_model,
-    dataset,
-    tokenizer,
-    batch_size,
-    num_batches,
-    beta: float,
-    delta: float,
-    max_seq_length=2048,
-    loss_fn: callable = dpo_loss,
-    loss_type="sigmoid",
-):
-    """
-    Modified evaluate function for DPO training.
-    """
-    all_losses = 0
-    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
-    ntokens = 0
-
-    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-
-    for _, batch in zip(
-        index_iterator,
-        iterate_dpo_batches(
-            dataset=dataset,
-            tokenizer=tokenizer,
-            batch_size=batch_size,
-            max_seq_length=max_seq_length,
-        ),
-    ):
-        chosen, rejected, chosen_masks, rejected_masks = batch
-        loss, reward, toks = loss_fn(
-            model=model,
-            reference_teacher_model=reference_model,
-            chosen=chosen,
-            rejected=rejected,
-            chosen_masks=chosen_masks,
-            rejected_masks=rejected_masks,
-            loss_type=loss_type,
-            beta=beta,
-            delta=delta,
-        )
-
-        all_losses += loss * toks
-        all_rewards += reward
-        ntokens += toks
-        mx.eval(all_losses, all_rewards, ntokens)
-
-    all_losses = mx.distributed.all_sum(all_losses)
-    all_rewards = mx.distributed.all_sum(all_rewards)
-    ntokens = mx.distributed.all_sum(ntokens)
-
-    return (all_losses / ntokens).item(), all_rewards.tolist()
-
-def train_dpo(
-    model,
-    reference_model,
-    tokenizer,
-    optimizer,
-    train_dataset,
-    val_dataset,
-    args: DPOTrainingArgs = DPOTrainingArgs(),
-    loss_fn: callable = dpo_loss,
-    training_callback: TrainingCallback = None,
-    loss_type="sigmoid",
-):
-    """
-    Modified training function for DPO.
-    """
-    print(f"Starting DPO training..., iters: {args.iters}")
-    world = mx.distributed.init()
-    world_size = world.size()
-    rank = world.rank()
-    if world_size > 1:
-        print(f"Node {rank} of {world_size}")
-
-    if args.grad_checkpoint:
-        grad_checkpoint(model.layers[0])
-
-    state = [model.state, optimizer.state]
-
-    def step(batch):
-        chosen, rejected, chosen_masks, rejected_masks = batch
-
-        # Remove loss_type from the call
-        (loss, reward, toks), grad = loss_value_and_grad(
-            model,
-            reference_model,
-            chosen,
-            rejected,
-            chosen_masks,
-            rejected_masks
-        )
-
-        # All reduce the gradients if running in distributed mode
-        grad = average_gradients(grad)
-
-        # Model update
-        optimizer.update(model, grad)
-
-        return loss, reward, toks
-
-    # Create a wrapper function that includes all required arguments
-    def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
-        return loss_fn(
-            model=model,
-            reference_teacher_model=ref_model,
-            chosen=chosen,
-            rejected=rejected,
-            chosen_masks=chosen_masks,
-            rejected_masks=rejected_masks,
-            beta=args.beta,
-            delta=args.delta,
-            loss_type=loss_type,
-            is_reference_free=args.is_reference_free
-        )
-
-    # Create value_and_grad with the wrapper
-    loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
-
-    losses = 0
-    rewards = mx.zeros((2,))
-    n_tokens = 0
-    steps = 0
-    trained_tokens = 0
-
-    # Main training loop
-    start = time.perf_counter()
-    for it, batch in zip(
-        range(1, args.iters + 1),
-        iterate_dpo_batches(
-            dataset=train_dataset,
-            tokenizer=tokenizer,
-            batch_size=args.batch_size,
-            max_seq_length=args.max_seq_length,
-            train=True,
-        ),
-    ):
-        # Report validation loss if needed
-        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
-            val_loss, val_rewards = evaluate_dpo(
-                model=model,
-                reference_model=reference_model,
-                dataset=val_dataset,
-                tokenizer=tokenizer,
-                batch_size=args.batch_size,
-                num_batches=args.val_batches,
-                max_seq_length=args.max_seq_length,
-                loss_fn=loss_fn,
-                beta=args.beta,
-                delta=args.delta,
-                loss_type=loss_type,
-            )
-            val_time = time.perf_counter() - stop
-            if rank == 0:
-                print(
-                    f"Iter {it}: "
-                    f"Val loss {val_loss:.3f}, "
-                    f"Val chosen reward {val_rewards[0]:.3f}, "
-                    f"Val rejected reward {val_rewards[1]:.3f}, "
-                    f"Val took {val_time:.3f}s",
-                    flush=True,
-                )
-
-            if training_callback is not None:
-                val_info = {
-                    "iteration": it,
-                    "val_loss": val_loss,
-                    "val_chosen_reward": val_rewards[0],
-                    "val_rejected_reward": val_rewards[1],
-                    "val_time": val_time,
-                }
-                training_callback.on_val_loss_report(val_info)
-
-            start = time.perf_counter()
-
-        loss, reward, toks = step(batch)
-        losses += loss
-        rewards += reward
-        n_tokens += toks
-        steps += 1
-        mx.eval(state, losses, rewards, n_tokens)
-
-        # Report training loss if needed
-        if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
-            train_loss = mx.distributed.all_sum(losses).item()
-            train_loss /= steps * world_size
-            train_rewards = mx.distributed.all_sum(rewards).tolist()
-            train_rewards = [r / (steps * world_size) for r in train_rewards]
-            n_tokens = mx.distributed.all_sum(n_tokens).item()
-            learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
-            trained_tokens += n_tokens
-            peak_mem = mx.metal.get_peak_memory() / 1e9
-
-            if rank == 0:
-                print(
-                    f"Iter {it}: Train loss {train_loss:.3f}, "
-                    f"Chosen reward {train_rewards[0]:.3f}, "
-                    f"Rejected reward {train_rewards[1]:.3f}, "
-                    f"Learning Rate {learning_rate:.3e}, "
-                    f"It/sec {it_sec:.3f}, "
-                    f"Tokens/sec {tokens_sec:.3f}, "
-                    f"Trained Tokens {trained_tokens}, "
-                    f"Peak mem {peak_mem:.3f} GB",
-                    flush=True,
-                )
-
-            if training_callback is not None:
-                train_info = {
-                    "iteration": it,
-                    "train_loss": train_loss,
-                    "train_chosen_reward": train_rewards[0],
-                    "train_rejected_reward": train_rewards[1],
-                    "learning_rate": learning_rate,
-                    "iterations_per_second": it_sec,
-                    "tokens_per_second": tokens_sec,
-                    "trained_tokens": trained_tokens,
-                    "peak_memory": peak_mem,
-                }
-                training_callback.on_train_loss_report(train_info)
-
-            losses = 0
-            rewards = mx.zeros((2,))
-            n_tokens = 0
-            steps = 0
-            start = time.perf_counter()
-
-        # Save adapter weights
-        if it % args.steps_per_save == 0:
-            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
-            mx.save_safetensors(str(args.adapter_file), adapter_weights)
-            checkpoint = (
-                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
-            )
-            mx.save_safetensors(str(checkpoint), adapter_weights)
-            print(
-                f"Iter {it}: Saved adapter weights to "
-                f"{args.adapter_file} and {checkpoint}."
-            )
-
-    # Save final weights
-    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
-    mx.save_safetensors(str(args.adapter_file), adapter_weights)
-    print(f"Saved final weights to {args.adapter_file}.")
\ No newline at end of file
diff --git a/llms/mlx_lm/tuner/orpo_trainer.py b/llms/mlx_lm/tuner/orpo_trainer.py
index cadfb049..6963ca40 100644
--- a/llms/mlx_lm/tuner/orpo_trainer.py
+++ b/llms/mlx_lm/tuner/orpo_trainer.py
@@ -14,128 +14,48 @@ from .trainer import TrainingArgs, grad_checkpoint, TrainingCallback
 class ORPOTrainingArgs(TrainingArgs):
     beta: float = field(
         default=0.1,
-        metadata={"help": "Temperature parameter for DPO training."}
-    )
-    reward_scaling: float = field(
-        default=1.0,
-        metadata={"help": "Scaling factor for offline rewards."}
+        metadata={"help": "Temperature parameter for ORPO training."}
     )
 
 
-def orpo_loss(
-    model,
-    chosen: mx.array,
-    rejected: mx.array,
-    chosen_masks: mx.array,
-    rejected_masks: mx.array,
-    chosen_rewards: mx.array,
-    rejected_rewards: mx.array,
-    beta: float = 0.1,
-    reward_scaling: float = 1.0,
-):
-    """
-    Calculate ORPO loss using pre-computed rewards that incorporate preference scores.
-    Args:
-        model: Policy model
-        chosen: Chosen sequence tokens
-        rejected: Rejected sequence tokens
-        chosen_masks: Attention masks for chosen sequences
-        rejected_masks: Attention masks for rejected sequences
-        chosen_rewards: Rewards for chosen sequences (derived from preference scores)
-        rejected_rewards: Rewards for rejected sequences (derived from preference scores)
-        beta: Temperature parameter
-        reward_scaling: Scaling factor for rewards
-    Returns:
-        Loss value, rewards, and number of tokens.
-    """
-    def make_predictions(model, x, mask):
-        inputs = x[:, :-1]
-        targets = x[:, 1:]
-
-        logits = model(inputs)
-        logits = logits.astype(mx.float32)
-
-        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
+def orpo_loss(model, chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards, beta=0.1):
+    def get_logps(model, x, mask):
+        inputs = x[:, :-1]
+        targets = x[:, 1:]
+        logits = model(inputs)
+        logp = -nn.losses.cross_entropy(logits, targets, reduction='none')
+        seq_lengths = mask[:, :-1].sum(-1)
+        logp_sum = (logp * mask[:, :-1]).sum(-1) / seq_lengths
+        logits_mean = (logits * mask[:, :-1, None]).sum() / mask[:, :-1].sum()
+        return logp_sum, logits_mean
 
-    # Calculate log probabilities for policy model
-    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
-    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
-
-    # Scale the pre-computed rewards
-    chosen_rewards = chosen_rewards * reward_scaling
-    rejected_rewards = rejected_rewards * reward_scaling
-
-    # Calculate reward difference
-    reward_diff = chosen_rewards - rejected_rewards
-
-    # Calculate ORPO loss using logistic function
-    policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
-    loss = -nn.log_sigmoid(beta * (policy_diff * reward_diff))
-
-    loss = mx.mean(loss)
-
-    # Calculate number of tokens for logging
-    num_tokens = (chosen_masks.sum() + rejected_masks.sum())
-
-    # Calculate rewards for logging
-    avg_chosen_reward = mx.mean(chosen_rewards)
-    avg_rejected_reward = mx.mean(rejected_rewards)
-    reward = mx.stack([avg_chosen_reward, avg_rejected_reward])
-
-    return loss, reward, num_tokens
-
-
-def evaluate_orpo(
-    model,
-    dataset,
-    tokenizer,
-    batch_size,
-    num_batches,
-    beta: float,
-    reward_scaling: float = 1.0,
-    max_seq_length=2048,
-):
-    """
-    Evaluation function for ORPO.
-    """
-    all_losses = 0
-    all_rewards = mx.zeros((2,))
-    ntokens = 0
-
-    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
-
-    for _, batch in zip(
-        index_iterator,
-        iterate_orpo_batches(
-            dataset=dataset,
-            tokenizer=tokenizer,
-            batch_size=batch_size,
-            max_seq_length=max_seq_length,
-        ),
-    ):
-        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
-        loss, reward, toks = orpo_loss(
-            model=model,
-            chosen=chosen,
-            rejected=rejected,
-            chosen_masks=chosen_masks,
-            rejected_masks=rejected_masks,
-            chosen_rewards=chosen_rewards,
-            rejected_rewards=rejected_rewards,
-            beta=beta,
-            reward_scaling=reward_scaling,
-        )
-
-        all_losses += loss * toks
-        all_rewards += reward
-        ntokens += toks
-        mx.eval(all_losses, all_rewards, ntokens)
-
-    all_losses = mx.distributed.all_sum(all_losses)
-    all_rewards = mx.distributed.all_sum(all_rewards)
-    ntokens = mx.distributed.all_sum(ntokens)
-
-    return (all_losses / ntokens).item(), all_rewards.tolist()
+    policy_chosen_logps, chosen_logits_mean = get_logps(model, chosen, chosen_masks)
+    policy_rejected_logps, rejected_logits_mean = get_logps(model, rejected, rejected_masks)
+
+    log_odds = (policy_chosen_logps - policy_rejected_logps) - (
+        mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
+    )
+
+    ratio = nn.log_sigmoid(log_odds)
+    loss = -beta * ratio
+
+    accuracies = (log_odds > 0).astype(mx.float32)
+    margins = mx.mean(ratio)
+    metrics = {
+        'accuracies': mx.mean(accuracies),
+        'margins': margins,
+        'policy_rejected_logps': mx.mean(policy_rejected_logps),
+        'policy_chosen_logps': mx.mean(policy_chosen_logps),
+        'rejected_logits_mean': mx.mean(rejected_logits_mean),
+        'chosen_logits_mean': mx.mean(chosen_logits_mean)
+    }
+
+    chosen_reward = beta * policy_chosen_logps
+    rejected_reward = beta * policy_rejected_logps
+    reward = mx.stack([mx.mean(chosen_reward), mx.mean(rejected_reward)])
+    num_tokens = chosen_masks.sum() + rejected_masks.sum()
+
+    return mx.mean(loss), reward, num_tokens, metrics
 
 
 def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
@@ -188,10 +108,6 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
 
             # Get preference scores and convert to rewards
             preference_scores = np.array([x.get('preference_score', 1.0) for x in batch], np.float32)
-            # Convert preference scores to chosen/rejected rewards
-            # When preference_score is 1.0, chosen_reward=1.0, rejected_reward=0.0
-            # When preference_score is 0.0, chosen_reward=0.0, rejected_reward=1.0
-            # When preference_score is 0.5, both rewards are 0.5
             chosen_rewards = preference_scores
             rejected_rewards = 1.0 - preference_scores
 
@@ -218,6 +134,56 @@ def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=F
             break
 
 
+def evaluate_orpo(model, dataset, tokenizer, batch_size, num_batches, beta: float, max_seq_length=2048):
+    all_losses = 0
+    all_rewards = mx.zeros((2,))
+    all_metrics = None
+    ntokens = 0
+
+    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
+    for _, batch in zip(
+        index_iterator,
+        iterate_orpo_batches(
+            dataset=dataset,
+            tokenizer=tokenizer,
+            batch_size=batch_size,
+            max_seq_length=max_seq_length,
+        ),
+    ):
+        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
+        loss, reward, toks, metrics = orpo_loss(
+            model=model,
+            chosen=chosen,
+            rejected=rejected,
+            chosen_masks=chosen_masks,
+            rejected_masks=rejected_masks,
+            chosen_rewards=chosen_rewards,
+            rejected_rewards=rejected_rewards,
+            beta=beta
+        )
+        all_losses += loss * toks
+        all_rewards += reward * toks
+        ntokens += toks
+
+        if all_metrics is None:
+            all_metrics = {k: v * toks for k, v in metrics.items()}
+        else:
+            for k, v in metrics.items():
+                all_metrics[k] += v * toks
+
+    mx.eval(all_losses, all_rewards, ntokens)
+    all_losses = mx.distributed.all_sum(all_losses)
+    all_rewards = mx.distributed.all_sum(all_rewards)
+    ntokens = mx.distributed.all_sum(ntokens)
+    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
+
+    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
+    avg_rewards = (all_rewards / ntokens).tolist()
+    avg_loss = (all_losses / ntokens).item()
+
+    return avg_loss, avg_rewards, ntokens, avg_metrics
+
+
 def train_orpo(
     model,
     tokenizer,
@@ -227,9 +193,6 @@ def train_orpo(
     args: ORPOTrainingArgs = ORPOTrainingArgs(),
     training_callback: TrainingCallback = None,
 ):
-    """
-    Training function for ORPO.
-    """
     print(f"Starting ORPO training..., iters: {args.iters}")
     world = mx.distributed.init()
     world_size = world.size()
@@ -246,7 +209,7 @@ def train_orpo(
     def step(batch):
         chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
 
-        (loss, reward, toks), grad = loss_value_and_grad(
+        (loss, reward, toks, metrics), grad = loss_value_and_grad(
            model,
            chosen,
            rejected,
@@ -259,7 +222,7 @@ def train_orpo(
         grad = average_gradients(grad)
 
         optimizer.update(model, grad)
-        return loss, reward, toks
+        return loss, reward, toks, metrics
 
     def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
                      chosen_rewards, rejected_rewards):
@@ -271,8 +234,7 @@ def train_orpo(
             rejected_masks=rejected_masks,
             chosen_rewards=chosen_rewards,
             rejected_rewards=rejected_rewards,
-            beta=args.beta,
-            reward_scaling=args.reward_scaling
+            beta=args.beta
         )
 
     loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
@@ -283,11 +245,19 @@ def train_orpo(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    accumulated_metrics = {
+        'accuracies': 0,
+        'margins': 0,
+        'policy_rejected_logps': 0,
+        'policy_chosen_logps': 0,
+        'rejected_logits_mean': 0,
+        'chosen_logits_mean': 0
+    }
 
     start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
-        iterate_orpo_batches(  # reuse DPO batch iterator
+        iterate_orpo_batches(
             dataset=train_dataset,
             tokenizer=tokenizer,
             batch_size=args.batch_size,
@@ -295,18 +265,16 @@ def train_orpo(
             train=True,
         ),
     ):
-        # Evaluate if needed
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
             stop = time.perf_counter()
-            val_loss, val_rewards = evaluate_orpo(
+            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_orpo(
                 model=model,
                 dataset=val_dataset,
                 tokenizer=tokenizer,
                 batch_size=args.batch_size,
                 num_batches=args.val_batches,
                 max_seq_length=args.max_seq_length,
-                beta=args.beta,
-                reward_scaling=args.reward_scaling,
+                beta=args.beta
             )
             val_time = time.perf_counter() - stop
             if rank == 0:
@@ -315,6 +283,8 @@ def train_orpo(
                     f"Val loss {val_loss:.8f}, "
                     f"Val chosen reward {val_rewards[0]:.3f}, "
                     f"Val rejected reward {val_rewards[1]:.3f}, "
+                    f"Val accuracy {val_metrics['accuracies']:.3f}, "
+                    f"Val margin {val_metrics['margins']:.3f}, "
                     f"Val took {val_time:.3f}s",
                     flush=True,
                 )
@@ -325,25 +295,28 @@ def train_orpo(
                     "val_loss": val_loss,
                     "val_chosen_reward": val_rewards[0],
                     "val_rejected_reward": val_rewards[1],
+                    **{f"val_{k}": v for k, v in val_metrics.items()},
                     "val_time": val_time,
                 })
 
            start = time.perf_counter()
 
        # Training step
-        loss, reward, toks = step(batch)
+        loss, reward, toks, metrics = step(batch)
         losses += loss
         rewards += reward
         n_tokens += toks
         steps += 1
+        for k, v in metrics.items():
+            accumulated_metrics[k] += v
         mx.eval(state, losses, rewards, n_tokens)
 
-        # Report training metrics if needed
         if it % args.steps_per_report == 0 or it == args.iters:
             stop = time.perf_counter()
 
             train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
             train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
+            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
             n_tokens = mx.distributed.all_sum(n_tokens).item()
             learning_rate = optimizer.learning_rate.item()
             it_sec = args.steps_per_report / (stop - start)
@@ -356,10 +329,11 @@ def train_orpo(
                     f"Iter {it}: Train loss {train_loss:.8f}, "
                     f"Chosen reward {train_rewards[0]:.3f}, "
                     f"Rejected reward {train_rewards[1]:.3f}, "
+                    f"Accuracy {avg_metrics['accuracies']:.3f}, "
+                    f"Margin {avg_metrics['margins']:.3f}, "
                     f"Learning Rate {learning_rate:.3e}, "
                     f"It/sec {it_sec:.3f}, "
                     f"Tokens/sec {tokens_sec:.3f}, "
-                    f"Trained Tokens {trained_tokens}, "
                     f"Peak mem {peak_mem:.3f} GB",
                     flush=True,
                 )
@@ -370,6 +344,7 @@ def train_orpo(
                     "train_loss": train_loss,
                     "train_chosen_reward": train_rewards[0],
                     "train_rejected_reward": train_rewards[1],
+                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                     "learning_rate": learning_rate,
                     "iterations_per_second": it_sec,
                     "tokens_per_second": tokens_sec,
@@ -381,9 +356,9 @@ def train_orpo(
             rewards = mx.zeros((2,))
             n_tokens = 0
             steps = 0
+            accumulated_metrics = {k: 0 for k in accumulated_metrics}
             start = time.perf_counter()
 
-        # Save model weights if needed
         if it % args.steps_per_save == 0:
             adapter_weights = dict(tree_flatten(model.trainable_parameters()))
             mx.save_safetensors(str(args.adapter_file), adapter_weights)
@@ -396,7 +371,6 @@ def train_orpo(
                 f"{args.adapter_file} and {checkpoint}."
             )
 
-    # Save final weights
     adapter_weights = dict(tree_flatten(model.trainable_parameters()))
     mx.save_safetensors(str(args.adapter_file), adapter_weights)
     print(f"Saved final weights to {args.adapter_file}.")
\ No newline at end of file
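Note (not part of the patch): a minimal standalone sketch of the odds-ratio term that the new orpo_loss computes, assuming mlx is installed; the length-normalized log-probability values below are invented purely for illustration of the formula in the diff.

# Illustrative only -- mirrors the log-odds / log-sigmoid computation added in orpo_loss.
import mlx.core as mx
import mlx.nn as nn

beta = 0.1
# Hypothetical length-normalized log-probs for two chosen/rejected pairs
# (the role played by get_logps in the patched orpo_trainer.py).
policy_chosen_logps = mx.array([-0.8, -1.1])
policy_rejected_logps = mx.array([-1.4, -1.3])

# Log-odds of preferring chosen over rejected, as in the new orpo_loss.
log_odds = (policy_chosen_logps - policy_rejected_logps) - (
    mx.log1p(-mx.exp(policy_chosen_logps)) - mx.log1p(-mx.exp(policy_rejected_logps))
)
loss = -beta * nn.log_sigmoid(log_odds)                 # per-pair ORPO penalty term
accuracy = mx.mean((log_odds > 0).astype(mx.float32))   # matches the 'accuracies' metric
print(mx.mean(loss).item(), accuracy.item())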