commit fa80d081f2 (parent 7d279b51ef): finish
@@ -20,6 +20,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [DPO Training](#DPO-Training)
- [ORPO Training](#ORPO-Training)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@@ -105,6 +106,38 @@ For DPO training, the data should be in JSONL format with the following structure:
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```

### ORPO Training

Offline Reward Policy Optimization (ORPO) training lets you fine-tune models on human preference data with pre-computed rewards. To use ORPO training, set the training mode to `orpo`:

```shell
mlx_lm.lora \
    --model <path_to_model> \
    --train \
    --training-mode orpo \
    --data <path_to_data> \
    --beta 0.1 \
    --reward-scaling 1.0
```

ORPO training accepts the following additional parameters (a programmatic equivalent is sketched after the list):
- `--beta`: Controls the temperature parameter for the logistic function (default: 0.1)
- `--reward-scaling`: Scaling factor for the offline rewards (default: 1.0)
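
The same options can also be set programmatically. The snippet below is only an illustrative sketch: it assumes the `mlx_lm` fork from this change is installed and simply constructs `ORPOTrainingArgs` (which extends `DPOTrainingArgs` in `tuner/orpo_trainer.py`) with the two ORPO-specific values; every other field keeps its default.

```python
# Illustrative sketch only: build ORPO training arguments in Python.
# Assumes the mlx_lm fork from this commit; the fields mirror the CLI flags above.
from mlx_lm.tuner.orpo_trainer import ORPOTrainingArgs

args = ORPOTrainingArgs(
    batch_size=4,
    iters=1000,
    beta=0.1,            # --beta: temperature for the logistic loss
    reward_scaling=1.0,  # --reward-scaling: scaling applied to the offline rewards
)
print(args)
```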

For ORPO training, the data should be in JSONL format with the following structure:

```jsonl
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```

The training process will automatically assign binary rewards (1.0 for chosen and 0.0 for rejected responses) if no explicit rewards are provided. You can also provide custom rewards in your data:

```jsonl
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "chosen_reward": 0.8, "rejected_reward": 0.3}
```
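
To make the expected file contents concrete, here is a small hypothetical helper script (not part of the repository) that writes a toy dataset in exactly this JSONL format, including the optional reward keys:

```python
# Hypothetical helper: write a toy ORPO preference dataset in the JSONL format above.
import json

examples = [
    {"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"},
    # Optional pre-computed rewards use the documented keys.
    {
        "prompt": "User prompt",
        "chosen": "Preferred response",
        "rejected": "Less preferred response",
        "chosen_reward": 0.8,
        "rejected_reward": 0.3,
    },
]

with open("train.jsonl", "w") as f:
    for example in examples:
        f.write(json.dumps(example) + "\n")
```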

### Evaluate

To compute test set perplexity use:

@@ -16,6 +16,7 @@ from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
from .tuner.utils import (
    build_schedule,
    linear_to_lora_layers,
@@ -70,6 +71,7 @@ CONFIG_DEFAULTS = {
    "delta": 50.0,
    "reference_model_path": None,
    "train_bias_only": False,
    "reward_scaling": 1.0,
}


@@ -106,7 +108,7 @@ def build_parser():
        "--training-mode",
        type=str,
        choices=["normal", "dpo", "orpo"],
        help="Training mode: normal, DPO or ORPO",
        help="Training mode: normal, DPO or ORPO.",
    )
    parser.add_argument(
        "--num-layers",
@@ -149,7 +151,7 @@ def build_parser():
    parser.add_argument(
        "--test",
        action="store_true",
        help="Evaluate on the test set after training",
        help="Evaluate on the test set after training.",
        default=None,
    )
    parser.add_argument(
@@ -166,7 +168,7 @@ def build_parser():
        "-c",
        "--config",
        type=str,
        help="A YAML configuration file with the training options",
        help="A YAML configuration file with the training options.",
    )
    parser.add_argument(
        "--grad-checkpoint",
@@ -180,7 +182,8 @@ def build_parser():
    parser.add_argument("--delta", type=float)
    parser.add_argument("--reference-model-path", type=str)
    parser.add_argument("--train-bias-only", action="store_true")
    parser.add_argument("--seed", type=int, help="The PRNG seed")
    parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
    parser.add_argument("--seed", type=int, help="The PRNG seed.")
    return parser


@@ -226,7 +229,8 @@ def train_model(
            build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
        )
    )
    # Train model

    # Train model based on training mode
    if args.training_mode == "dpo":
        training_args = DPOTrainingArgs(
            batch_size=args.batch_size,
@@ -261,6 +265,32 @@ def train_model(
            args=training_args,
            training_callback=training_callback,
        )
    elif args.training_mode == "orpo":
        training_args = ORPOTrainingArgs(
            batch_size=args.batch_size,
            iters=args.iters,
            val_batches=args.val_batches,
            steps_per_report=args.steps_per_report,
            steps_per_eval=args.steps_per_eval,
            steps_per_save=args.save_every,
            adapter_file=adapter_file,
            max_seq_length=args.max_seq_length,
            grad_checkpoint=args.grad_checkpoint,
            beta=args.beta,
            reward_scaling=args.reward_scaling,
            train_bias_only=args.train_bias_only,
            seed=args.seed,
        )

        train_orpo(
            model=model,
            tokenizer=tokenizer,
            optimizer=opt,
            train_dataset=train_set,
            val_dataset=valid_set,
            args=training_args,
            training_callback=training_callback,
        )
    else:
        training_args = TrainingArgs(
            batch_size=args.batch_size,
@@ -304,7 +334,19 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
            max_seq_length=args.max_seq_length,
            beta=args.beta,
            delta=args.delta,
            loss_type=args.loss_type,
            loss_type=args.dpo_loss_type,
        )
        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
    elif args.training_mode == "orpo":
        test_loss, test_rewards = evaluate_orpo(
            model=model,
            dataset=test_set,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            num_batches=args.test_batches,
            max_seq_length=args.max_seq_length,
            beta=args.beta,
            reward_scaling=args.reward_scaling,
        )
        print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
    else:
@@ -318,7 +360,6 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
        )

    test_ppl = math.exp(test_loss)
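    # Perplexity is the exponential of the average per-token cross-entropy (the test loss above).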

    print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")


@@ -12,7 +12,6 @@ import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from ..generate import generate
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs


@@ -1,54 +1,48 @@
# Copyright © 2024 Apple Inc.

import time
from pathlib import Path
from dataclasses import dataclass, field
import mlx.core as mx
import mlx.nn as nn

from dpo_trainer import DPOTrainingArgs, iterate_dpo_batches, train_dpo, TrainingCallback
import mlx.nn as nn
import mlx.core as mx
import numpy as np
from mlx.utils import tree_flatten
from mlx.nn.utils import average_gradients
from .dpo_trainer import DPOTrainingArgs, grad_checkpoint


@dataclass
class ORPOTrainingArgs(DPOTrainingArgs):
    """
    Training arguments specific to ORPO, extending DPO arguments.
    """
    mu: float = field(
        default=0.5,
        metadata={"help": "ORPO KL divergence weight parameter"}
    reward_scaling: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for offline rewards."}
    )


def orpo_loss(
    model,
    reference_teacher_model,
    chosen: mx.array,
    rejected: mx.array,
    chosen_masks: mx.array,
    rejected_masks: mx.array,
    chosen_rewards: mx.array,  # Pre-computed rewards
    rejected_rewards: mx.array,  # Pre-computed rewards
    beta: float,
    delta: float,
    mu: float = 0.5,  # ORPO hyperparameter for balancing KL divergence
    loss_type: str = "sigmoid",
    is_reference_free: bool = False
    reward_scaling: float = 1.0,
):
    """
    Calculate ORPO loss for inputs.
    ORPO extends DPO by adding a KL regularization term to prevent overfitting to preferences.

    Calculate ORPO loss using pre-computed rewards.
    Args:
        model: Policy model
        reference_teacher_model: Reference model
        chosen: Chosen sequence tokens
        rejected: Rejected sequence tokens
        chosen_masks: Masks for chosen sequences
        rejected_masks: Masks for rejected sequences
        chosen_masks: Attention masks for chosen sequences
        rejected_masks: Attention masks for rejected sequences
        chosen_rewards: Pre-computed rewards for chosen sequences
        rejected_rewards: Pre-computed rewards for rejected sequences
        beta: Temperature parameter
        delta: Margin for DPOP loss type
        mu: ORPO hyperparameter for balancing KL divergence (default: 0.5)
        loss_type: Loss type ('sigmoid', 'hinge', 'ipo', or 'dpop')
        is_reference_free: Whether to use reference-free training
        reward_scaling: Scaling factor for rewards
    Returns:
        Tuple of (loss, reward, num_tokens)
        Loss value, rewards, and number of tokens.
    """
    def make_predictions(model, x, mask):
        inputs = x[:, :-1]
@@ -59,218 +53,335 @@ def orpo_loss(

        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]

    num_chosen_tokens = chosen_masks.sum(-1)
    num_rejected_tokens = rejected_masks.sum(-1)

    # Calculate log probabilities for policy model
    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)

    # Calculate reference model scores
    if not is_reference_free:
        reference_chosen_scores = mx.stop_gradient(make_predictions(reference_teacher_model, chosen, chosen_masks))
        reference_rejected_scores = mx.stop_gradient(make_predictions(reference_teacher_model, rejected, rejected_masks))
    else:
        reference_chosen_scores = mx.zeros_like(policy_chosen_scores)
        reference_rejected_scores = mx.zeros_like(policy_rejected_scores)

    # Compute average log probabilities if using IPO loss
    if loss_type == "ipo":
        policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens
        policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens
        reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
        reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens
    else:
        policy_chosen_score = policy_chosen_scores.sum(-1)
        policy_rejected_score = policy_rejected_scores.sum(-1)
        reference_chosen_score = reference_chosen_scores.sum(-1)
        reference_rejected_score = reference_rejected_scores.sum(-1)

    # Calculate preference logits
    logits = (policy_chosen_score - policy_rejected_score) - (reference_chosen_score - reference_rejected_score)

    # Calculate preference loss based on loss type
    if loss_type == "sigmoid":
        preference_loss = -nn.log_sigmoid(beta * logits)
    elif loss_type == "hinge":
        preference_loss = nn.relu(1 - beta * logits)
    elif loss_type == "ipo":
        preference_loss = (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "dpop":
        penalty = mx.maximum(mx.zeros_like(policy_chosen_score), reference_chosen_score - policy_chosen_score)
        preference_loss = -(nn.log_sigmoid(beta * logits) - delta * penalty)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    # Calculate KL divergence term for ORPO
    kl_div_chosen = mx.mean((policy_chosen_scores - reference_chosen_scores) ** 2)
    kl_div_rejected = mx.mean((policy_rejected_scores - reference_rejected_scores) ** 2)
    kl_regularization = mu * (kl_div_chosen + kl_div_rejected)

    # Combine preference loss and KL regularization
    loss = mx.mean(preference_loss) + kl_regularization
    # Scale the pre-computed rewards
    chosen_rewards = chosen_rewards * reward_scaling
    rejected_rewards = rejected_rewards * reward_scaling

    num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()

    # Calculate rewards for monitoring
    chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
    rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
    reward = mx.stack([chosen_reward, rejected_reward])
    # ORPO uses the reward difference directly
    reward_diff = chosen_rewards - rejected_rewards

    # Calculate ORPO loss using logistic function
    policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
    loss = -nn.log_sigmoid(beta * (policy_diff * reward_diff))
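    # With the default binary rewards and reward_scaling=1.0, reward_diff is 1, so this
    # reduces to -log_sigmoid(beta * policy_diff): the loss falls as the policy assigns
    # a higher total log-likelihood to the chosen response than to the rejected one.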

    loss = mx.mean(loss)

    # Calculate number of tokens for logging
    num_tokens = (chosen_masks.sum() + rejected_masks.sum())

    # Calculate rewards for logging
    avg_chosen_reward = mx.mean(chosen_rewards)
    avg_rejected_reward = mx.mean(rejected_rewards)
    reward = mx.stack([avg_chosen_reward, avg_rejected_reward])

    return loss, reward, num_tokens


def evaluate_orpo(
    model,
    reference_model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    delta: float,
    mu: float = 0.5,
    reward_scaling: float = 1.0,
    max_seq_length=2048,
    loss_type="sigmoid",
    is_reference_free=False,
):
    """
    Evaluate model using ORPO metrics.

    Args:
        model: Policy model to evaluate
        reference_model: Reference model for comparison
        dataset: Evaluation dataset
        tokenizer: Tokenizer for processing text
        batch_size: Batch size for evaluation
        num_batches: Number of batches to evaluate (-1 for full dataset)
        beta: Temperature parameter
        delta: Margin for DPOP loss
        mu: ORPO KL divergence weight
        max_seq_length: Maximum sequence length
        loss_type: Type of loss function
        is_reference_free: Whether to use reference-free evaluation

    Returns:
        Tuple of (loss, rewards, kl_metrics), where:
        - loss is the total ORPO loss
        - rewards is [chosen_reward, rejected_reward]
        - kl_metrics is [chosen_kl, rejected_kl]
    Evaluation function for ORPO.
    """
    all_losses = 0
    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
    all_kl_divs = mx.zeros((2,))  # [chosen_kl, rejected_kl]
    all_rewards = mx.zeros((2,))
    ntokens = 0

    def compute_kl_divergence(policy_scores, reference_scores, masks):
        """Helper function to compute KL divergence metrics."""
        # Using MSE as a proxy for KL divergence as in the loss function
        valid_tokens = masks.sum()
        kl_div = ((policy_scores - reference_scores) ** 2 * masks).sum() / valid_tokens
        return kl_div

    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
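    # iter(int, 1) is an endless iterator (int() returns 0, which never equals the sentinel 1),
    # so num_batches=-1 keeps consuming batches until the batch generator is exhausted.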

    for _, batch in zip(
        index_iterator,
        iterate_dpo_batches(  # Reusing DPO batch iterator
        iterate_orpo_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        chosen, rejected, chosen_masks, rejected_masks = batch

        # Get model predictions
        def make_predictions(model, x, mask):
            inputs = x[:, :-1]
            targets = x[:, 1:]
            logits = model(inputs)
            logits = logits.astype(mx.float32)
            return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]

        # Get scores for both models
        policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
        policy_rejected_scores = make_predictions(model, rejected, rejected_masks)

        if not is_reference_free:
            reference_chosen_scores = mx.stop_gradient(
                make_predictions(reference_model, chosen, chosen_masks)
            )
            reference_rejected_scores = mx.stop_gradient(
                make_predictions(reference_model, rejected, rejected_masks)
            )
        else:
            reference_chosen_scores = mx.zeros_like(policy_chosen_scores)
            reference_rejected_scores = mx.zeros_like(policy_rejected_scores)

        # Compute KL divergences
        chosen_kl = compute_kl_divergence(
            policy_chosen_scores, reference_chosen_scores, chosen_masks[:, :-1]
        )
        rejected_kl = compute_kl_divergence(
            policy_rejected_scores, reference_rejected_scores, rejected_masks[:, :-1]
        )
        all_kl_divs += mx.stack([chosen_kl, rejected_kl])

        # Compute ORPO loss and rewards
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
        loss, reward, toks = orpo_loss(
            model=model,
            reference_teacher_model=reference_model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=beta,
            delta=delta,
            mu=mu,
            loss_type=loss_type,
            is_reference_free=is_reference_free,
            reward_scaling=reward_scaling,
        )

        all_losses += loss * toks
        all_rewards += reward
        ntokens += toks
        mx.eval(all_losses, all_rewards, all_kl_divs, ntokens)
        mx.eval(all_losses, all_rewards, ntokens)

    # Aggregate metrics across distributed workers if necessary
    all_losses = mx.distributed.all_sum(all_losses)
    all_rewards = mx.distributed.all_sum(all_rewards)
    all_kl_divs = mx.distributed.all_sum(all_kl_divs)
    ntokens = mx.distributed.all_sum(ntokens)

    # Normalize metrics
    avg_loss = (all_losses / ntokens).item()
    avg_rewards = [r / mx.distributed.init().size() for r in all_rewards.tolist()]
    avg_kl_divs = [kl / mx.distributed.init().size() for kl in all_kl_divs.tolist()]
    return (all_losses / ntokens).item(), all_rewards.tolist()

    return avg_loss, avg_rewards, avg_kl_divs


def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """
    Modified batch iterator for ORPO that includes pre-computed rewards.
    Works with pre-tokenized input data.
    """
    # Sort pairs by length of the chosen response
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
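    # Sorting by length groups similarly sized pairs into the same batch, which keeps the
    # amount of padding per batch small.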

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size}"
            f" examples but only has {len(dataset)}."
        )

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]
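    # Each batch slice keeps every `step`-th index of a window of batch_size sorted examples,
    # i.e. batch_size // step examples per worker, which is why the divisibility check above
    # is required.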

    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]

            # Get lengths assuming data is already tokenized
            chosen_lengths = [len(x['chosen']) for x in batch]
            rejected_lengths = [len(x['rejected']) for x in batch]
            max_length = max(max(chosen_lengths), max(rejected_lengths))

            if max_length > max_seq_length:
                print(
                    f"[WARNING] Sequences longer than {max_seq_length} tokens "
                    f"will be truncated."
                )

            pad_to = 8
            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
            max_length_in_batch = min(max_length_in_batch, max_seq_length)
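            # Round the batch's longest sequence up to a multiple of 8 (then cap at
            # max_seq_length) so padded array shapes vary less from batch to batch.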

            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)

            # Rewards default to the binary values (1.0 for chosen, 0.0 for rejected); note
            # that per-example "chosen_reward"/"rejected_reward" values in the data are not
            # read by this iterator.
            chosen_rewards = np.ones((batch_size // step,), np.float32)
            rejected_rewards = np.zeros((batch_size // step,), np.float32)

            for j in range(batch_size // step):
                # Use pre-tokenized sequences directly
                chosen_length = min(chosen_lengths[j], max_seq_length)
                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                chosen_masks[j, :chosen_length] = 1.0

                rejected_length = min(rejected_lengths[j], max_seq_length)
                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                rejected_masks[j, :rejected_length] = 1.0

            yield (mx.array(chosen_arr), mx.array(rejected_arr),
                   mx.array(chosen_masks), mx.array(rejected_masks),
                   mx.array(chosen_rewards), mx.array(rejected_rewards))

        if not train:
            break


def train_orpo(
    model,
    reference_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: ORPOTrainingArgs = ORPOTrainingArgs(),
    training_callback: TrainingCallback = None,
    training_callback = None,
):
    """
    Train a model using ORPO (Offline Rejection Preference Optimization).
    This function adapts the DPO training loop to use ORPO loss.
    Training function for ORPO.
    """
    return train_dpo(
        model=model,
        reference_model=reference_model,
        tokenizer=tokenizer,
        optimizer=optimizer,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        args=args,
        loss_fn=orpo_loss,
        training_callback=training_callback,
        loss_type=args.loss_type,
    )
    print(f"Starting ORPO training..., iters: {args.iters}")
    world = mx.distributed.init()
    world_size = world.size()
    rank = world.rank()

    if world_size > 1:
        print(f"Node {rank} of {world_size}")

    if args.grad_checkpoint:
        grad_checkpoint(model.layers[0])

    state = [model.state, optimizer.state]

    def step(batch):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch

        (loss, reward, toks), grad = loss_value_and_grad(
            model,
            chosen,
            rejected,
            chosen_masks,
            rejected_masks,
            chosen_rewards,
            rejected_rewards
        )

        grad = average_gradients(grad)
        optimizer.update(model, grad)

        return loss, reward, toks

    def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
                     chosen_rewards, rejected_rewards):
        return orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=args.beta,
            reward_scaling=args.reward_scaling
        )

    loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
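    # nn.value_and_grad differentiates loss_wrapper with respect to the model's trainable
    # parameters; it returns the wrapper's outputs (loss, reward, toks) together with the
    # gradient tree that `step` averages across workers and feeds to the optimizer.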

    # Training loop with progress tracking
    losses = 0
    rewards = mx.zeros((2,))
    n_tokens = 0
    steps = 0
    trained_tokens = 0

    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_orpo_batches(  # ORPO batch iterator (also yields the pre-computed rewards)
            dataset=train_dataset,
            tokenizer=tokenizer,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        # Evaluate if needed
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_rewards = evaluate_orpo(
                model=model,
                dataset=val_dataset,
                tokenizer=tokenizer,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                beta=args.beta,
                reward_scaling=args.reward_scaling,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                print(
                    f"Iter {it}: "
                    f"Val loss {val_loss:.3f}, "
                    f"Val chosen reward {val_rewards[0]:.3f}, "
                    f"Val rejected reward {val_rewards[1]:.3f}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    "val_chosen_reward": val_rewards[0],
                    "val_rejected_reward": val_rewards[1],
                    "val_time": val_time,
                })

            start = time.perf_counter()

        # Training step
        loss, reward, toks = step(batch)
        losses += loss
        rewards += reward
        n_tokens += toks
        steps += 1
        mx.eval(state, losses, rewards, n_tokens)
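        # mx.eval forces evaluation of the lazy computation graph here, materializing the
        # optimizer update in `state` and the accumulated metrics for this step.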

        # Report training metrics if needed
        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
            train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
            n_tokens = mx.distributed.all_sum(n_tokens).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                print(
                    f"Iter {it}: Train loss {train_loss:.3f}, "
                    f"Chosen reward {train_rewards[0]:.3f}, "
                    f"Rejected reward {train_rewards[1]:.3f}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Trained Tokens {trained_tokens}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_train_loss_report({
                    "iteration": it,
                    "train_loss": train_loss,
                    "train_chosen_reward": train_rewards[0],
                    "train_rejected_reward": train_rewards[1],
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                })

            losses = 0
            rewards = mx.zeros((2,))
            n_tokens = 0
            steps = 0
            start = time.perf_counter()

        # Save model weights if needed
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")