mlx-examples/llms/mlx_lm/tuner/dpo_trainer.py

# Copyright © 2024 Apple Inc.
import glob
import shutil
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Union
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs


@dataclass
class DPOTrainingArgs(TrainingArgs):
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for DPO training."}
    )
    loss_type: str = field(
        default="sigmoid",
        metadata={
            "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
        }
    )
    delta: float = field(
        default=50.0,
        metadata={
            "help": "Delta parameter for DPOP loss type."
        }
    )
    reference_model_path: str = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        }
    )


def dpo_loss(
    model,
    chosen: mx.array,
    rejected: mx.array,
    chosen_masks: mx.array,
    rejected_masks: mx.array,
    beta: float,
    delta: float,
    loss_type: str = "sigmoid",
    ref_model=None,
):
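    """Compute the DPO preference loss for a batch of chosen/rejected pairs.

    The policy and reference models score each sequence with summed (or, for
    'ipo', length-averaged) token log-probabilities, and the preference logit is

        logits = (log pi(chosen) - log pi(rejected))
                 - (log ref(chosen) - log ref(rejected))

    which ``loss_type`` turns into a scalar loss: 'sigmoid' uses
    -log(sigmoid(beta * logits)) as in the original DPO paper, 'hinge' uses
    relu(1 - beta * logits), 'ipo' uses a squared distance to 1 / (2 * beta),
    and 'dpop' adds a penalty of delta times how far the policy's chosen score
    falls below the reference's. If ``ref_model`` is None the reference scores
    are zero. Returns (mean loss, rewards, token count, metrics).
    """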
    def make_predictions(model, x, mask):
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logits = logits.astype(mx.float32)

        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]

    num_chosen_tokens = chosen_masks.sum(-1)
    num_rejected_tokens = rejected_masks.sum(-1)

    # Calculate log probabilities for the policy model
    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
    if loss_type == "ipo":
        # ipo uses average log probabilities
        policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens
        policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens
    else:
        policy_chosen_score = policy_chosen_scores.sum(-1)
        policy_rejected_score = policy_rejected_scores.sum(-1)

    # Calculate log probabilities for the reference model
    if ref_model is None:
        reference_chosen_score = mx.zeros_like(policy_chosen_score)
        reference_rejected_score = mx.zeros_like(policy_rejected_score)
    else:
        reference_chosen_scores = mx.stop_gradient(
            make_predictions(ref_model, chosen, chosen_masks)
        )
        reference_rejected_scores = mx.stop_gradient(
            make_predictions(ref_model, rejected, rejected_masks)
        )
        if loss_type == "ipo":
            # ipo uses average log probabilities
            reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
            reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens
        else:
            reference_chosen_score = reference_chosen_scores.sum(-1)
            reference_rejected_score = reference_rejected_scores.sum(-1)

    logits = (policy_chosen_score - policy_rejected_score) - (
        reference_chosen_score - reference_rejected_score
    )

    if loss_type == "sigmoid":  # From the original DPO paper
        losses = -nn.log_sigmoid(beta * logits)
    elif loss_type == "hinge":
        losses = nn.relu(1 - beta * logits)
    elif loss_type == "ipo":
        losses = (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "dpop":
        penalty = mx.maximum(
            mx.zeros_like(policy_chosen_score),
            reference_chosen_score - policy_chosen_score,
        )
        losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()

    chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
    rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
    reward = mx.stack([chosen_reward, rejected_reward])

    metrics = {
        'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
        'margins': mx.mean(chosen_reward - rejected_reward),
        'policy_rejected_logps': mx.mean(policy_rejected_score / num_rejected_tokens),
        'policy_chosen_logps': mx.mean(policy_chosen_score / num_chosen_tokens),
        'rejected_logits_mean': mx.mean(policy_rejected_score),
        'chosen_logits_mean': mx.mean(policy_chosen_score)
    }

    return mx.mean(losses), reward, num_tokens, metrics


def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
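    """Yield (chosen, rejected, chosen_masks, rejected_masks) batches.

    Examples are sorted by the length of the chosen response, grouped into
    batches of ``batch_size`` (each distributed worker takes a strided slice),
    padded to the longest sequence in the batch capped at ``max_seq_length``,
    and yielded with float masks marking real tokens. When ``train`` is True
    the batch order is reshuffled on every pass; otherwise the iterator stops
    after one epoch.
    """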
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("Batch size must be divisible by workers")

    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]

            # Get and process lengths
            chosen_lengths = [len(x['chosen']) for x in batch]
            rejected_lengths = [len(x['rejected']) for x in batch]
            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)

            # Dynamic padding based on batch content
            max_length_in_batch = max_length

            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)

            for j in range(batch_size // step):
                chosen_length = min(chosen_lengths[j], max_seq_length)
                rejected_length = min(rejected_lengths[j], max_seq_length)
                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                chosen_masks[j, :chosen_length] = 1.0
                rejected_masks[j, :rejected_length] = 1.0

            yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks)

        if not train:
            break


def evaluate_dpo(
    model,
    ref_model,
    dataset,
    batch_size,
    num_batches,
    beta: float,
    delta: float,
    max_seq_length,
    loss_type,
    loss: callable = dpo_loss
):
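    """Run DPO evaluation over ``num_batches`` batches (all data if -1).

    Accumulates token-weighted losses, rewards, and metrics, reduces them
    across distributed workers with ``mx.distributed.all_sum``, and returns
    (avg_loss, [chosen_reward, rejected_reward], ntokens, avg_metrics).
    """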
    all_losses = 0
    all_rewards = mx.zeros((2,))
    all_metrics = None
    ntokens = 0

    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_dpo_batches(
            dataset=dataset,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        chosen, rejected, chosen_masks, rejected_masks = batch
        # Use a separate name so the `loss` callable is not shadowed on later iterations
        lvalue, reward, toks, metrics = loss(
            model=model,
            ref_model=ref_model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            loss_type=loss_type,
            beta=beta,
            delta=delta,
        )
        all_losses += lvalue * toks
        all_rewards += reward
        ntokens += toks

        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        mx.eval(all_losses, all_rewards, ntokens)

    all_losses = mx.distributed.all_sum(all_losses)
    all_rewards = mx.distributed.all_sum(all_rewards)
    ntokens = mx.distributed.all_sum(ntokens)
    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_rewards = (all_rewards / ntokens).tolist()
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, avg_rewards, ntokens, avg_metrics


def train_dpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: DPOTrainingArgs = DPOTrainingArgs(),
    loss: callable = dpo_loss,
    training_callback: TrainingCallback = None,
    loss_type="sigmoid",
):
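    """Train ``model`` with DPO against a (possibly None) reference model.

    Iterates over preference batches, computes the loss and gradients with
    ``nn.value_and_grad``, averages gradients across workers, applies the
    optimizer update, and periodically evaluates, reports metrics, and saves
    the trainable (adapter) weights to ``args.adapter_file``.
    """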
print(f"Starting DPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
def step(batch):
chosen, rejected, chosen_masks, rejected_masks = batch
2025-02-01 04:27:59 +08:00
(lvalue, reward, toks, metrics), grad = loss_value_and_grad(
2025-01-19 07:19:36 +08:00
model,
2025-01-31 07:01:43 +08:00
ref_model,
2025-01-19 07:19:36 +08:00
chosen,
rejected,
chosen_masks,
rejected_masks
)
grad = average_gradients(grad)
optimizer.update(model, grad)
2025-02-01 04:27:59 +08:00
return lvalue, reward, toks, metrics
2025-01-19 07:19:36 +08:00
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
2025-02-01 04:27:59 +08:00
return loss(
2025-01-19 07:19:36 +08:00
model=model,
reference_teacher_model=ref_model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
beta=args.beta,
delta=args.delta,
2025-01-31 07:01:43 +08:00
loss_type=loss_type
2025-01-19 07:19:36 +08:00
)
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
trained_tokens = 0
2025-01-26 22:09:55 +08:00
accumulated_metrics = {
'accuracies': 0,
'margins': 0,
'policy_rejected_logps': 0,
'policy_chosen_logps': 0,
'rejected_logits_mean': 0,
'chosen_logits_mean': 0
}
2025-01-19 07:19:36 +08:00
    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_dpo_batches(
            dataset=train_dataset,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                model=model,
                ref_model=ref_model,
                dataset=val_dataset,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                loss=loss,
                beta=args.beta,
                delta=args.delta,
                loss_type=loss_type,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                print(
                    f"Iter {it}: "
                    f"Val loss {val_loss:.3f}, "
                    f"Val chosen reward {val_rewards[0]:.3f}, "
                    f"Val rejected reward {val_rewards[1]:.3f}, "
                    f"Val accuracy {val_metrics['accuracies']:.3f}, "
                    f"Val margin {val_metrics['margins']:.3f}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    "val_chosen_reward": val_rewards[0],
                    "val_rejected_reward": val_rewards[1],
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            start = time.perf_counter()

        lvalue, reward, toks, metrics = step(batch)
        losses += lvalue
        rewards += reward
        n_tokens += toks
        steps += 1

        for k, v in metrics.items():
            accumulated_metrics[k] += v

        mx.eval(state, losses, rewards, n_tokens)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
            train_rewards = mx.distributed.all_sum(rewards).tolist()
            train_rewards = [r / (steps * world_size) for r in train_rewards]
            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
            n_tokens = mx.distributed.all_sum(n_tokens).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                print(
                    f"Iter {it}: Train loss {train_loss:.3f}, "
                    f"Chosen reward {train_rewards[0]:.3f}, "
                    f"Rejected reward {train_rewards[1]:.3f}, "
                    f"Accuracy {avg_metrics['accuracies']:.3f}, "
                    f"Margin {avg_metrics['margins']:.3f}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Trained Tokens {trained_tokens}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                train_info = {
                    "iteration": it,
                    "train_loss": train_loss,
                    "train_chosen_reward": train_rewards[0],
                    "train_rejected_reward": train_rewards[1],
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                }
                training_callback.on_train_loss_report(train_info)

            # Reset accumulators for the next reporting window
            losses = 0
            rewards = mx.zeros((2,))
            n_tokens = 0
            steps = 0
            accumulated_metrics = {k: 0 for k in accumulated_metrics}
            start = time.perf_counter()

        # Save adapter weights
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")