mlx-examples/llms/mlx_lm/tuner/dpo_trainer.py

# Copyright © 2024 Apple Inc.

import time
from dataclasses import dataclass, field
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten

from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs

@dataclass
class DPOTrainingArgs(TrainingArgs):
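    """Training arguments for DPO; extends TrainingArgs with the DPO-specific
    hyperparameters below (beta, loss_type, delta, reference_model_path)."""
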
    beta: float = field(
        default=0.1,
        metadata={"help": "Temperature parameter for DPO training."}
    )
    loss_type: str = field(
        default="sigmoid",
        metadata={
            "help": "DPO loss type: 'sigmoid', 'hinge', 'ipo', or 'dpop'."
        }
    )
    delta: float = field(
        default=50.0,
        metadata={
            "help": "Delta parameter for DPOP loss type."
        }
    )
    reference_model_path: str = field(
        default=None,
        metadata={
            "help": "Path to reference model weights. If None, uses the same model."
        }
    )


def dpo_loss(
    model,
    chosen: mx.array,
    rejected: mx.array,
    chosen_masks: mx.array,
    rejected_masks: mx.array,
    beta: float,
    delta: float,
    loss_type: str = "sigmoid",
    ref_model=None,
):
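    """Compute the DPO preference loss for a batch of chosen/rejected pairs.

    Returns the mean loss, a (chosen_reward, rejected_reward) pair scaled by beta,
    the total number of target tokens, and a dict of logging metrics.
    """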
    def make_predictions(model, x, mask):
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logits = logits.astype(mx.float32)
        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]

    num_chosen_tokens = chosen_masks.sum(-1)
    num_rejected_tokens = rejected_masks.sum(-1)

    # Calculate log probabilities for the policy model
    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
    if loss_type == "ipo":
        # IPO uses average log probabilities
        policy_chosen_score = policy_chosen_scores.sum(-1) / num_chosen_tokens
        policy_rejected_score = policy_rejected_scores.sum(-1) / num_rejected_tokens
    else:
        policy_chosen_score = policy_chosen_scores.sum(-1)
        policy_rejected_score = policy_rejected_scores.sum(-1)

    # Calculate log probabilities for the reference model
    if ref_model is None:
        reference_chosen_score = mx.zeros_like(policy_chosen_score)
        reference_rejected_score = mx.zeros_like(policy_rejected_score)
    else:
        reference_chosen_scores = mx.stop_gradient(
            make_predictions(ref_model, chosen, chosen_masks)
        )
        reference_rejected_scores = mx.stop_gradient(
            make_predictions(ref_model, rejected, rejected_masks)
        )
        if loss_type == "ipo":
            # IPO uses average log probabilities
            reference_chosen_score = reference_chosen_scores.sum(-1) / num_chosen_tokens
            reference_rejected_score = reference_rejected_scores.sum(-1) / num_rejected_tokens
        else:
            reference_chosen_score = reference_chosen_scores.sum(-1)
            reference_rejected_score = reference_rejected_scores.sum(-1)

    logits = (policy_chosen_score - policy_rejected_score) - (
        reference_chosen_score - reference_rejected_score
    )

    if loss_type == "sigmoid":  # From the original DPO paper
        losses = -nn.log_sigmoid(beta * logits)
    elif loss_type == "hinge":
        losses = nn.relu(1 - beta * logits)
    elif loss_type == "ipo":
        losses = (logits - 1 / (2 * beta)) ** 2
    elif loss_type == "dpop":
        penalty = mx.maximum(
            mx.zeros_like(policy_chosen_score),
            reference_chosen_score - policy_chosen_score,
        )
        losses = -(nn.log_sigmoid(beta * logits) - delta * penalty)
    else:
        raise ValueError(f"Unknown loss type: {loss_type}")

    num_tokens = (num_chosen_tokens + num_rejected_tokens).sum()

    chosen_reward = beta * mx.mean(policy_chosen_score - reference_chosen_score)
    rejected_reward = beta * mx.mean(policy_rejected_score - reference_rejected_score)
    reward = mx.stack([chosen_reward, rejected_reward])

    metrics = {
        'accuracies': mx.mean((chosen_reward > rejected_reward).astype(mx.float32)),
        'margins': mx.mean(chosen_reward - rejected_reward),
        'policy_rejected_logps': mx.mean(policy_rejected_score / num_rejected_tokens),
        'policy_chosen_logps': mx.mean(policy_chosen_score / num_chosen_tokens),
        'rejected_logits_mean': mx.mean(policy_rejected_score),
        'chosen_logits_mean': mx.mean(policy_chosen_score)
    }

    return mx.mean(losses), reward, num_tokens, metrics

def iterate_dpo_batches(dataset, batch_size, max_seq_length, train=False):
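    """Yield (chosen, rejected, chosen_mask, rejected_mask) batches, sorted by
    length of the chosen sequence and sharded across distributed workers.
    Shuffles and loops forever when train=True; makes a single pass otherwise."""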
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("Batch size must be divisible by the number of workers")

    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]

            # Get and process lengths
            chosen_lengths = [len(x['chosen']) for x in batch]
            rejected_lengths = [len(x['rejected']) for x in batch]
            max_length = min(max(max(chosen_lengths), max(rejected_lengths)), max_seq_length)

            # Dynamic padding based on batch content
            max_length_in_batch = max_length

            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)

            for j in range(batch_size // step):
                chosen_length = min(chosen_lengths[j], max_seq_length)
                rejected_length = min(rejected_lengths[j], max_seq_length)
                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                chosen_masks[j, :chosen_length] = 1.0
                rejected_masks[j, :rejected_length] = 1.0

            yield mx.array(chosen_arr), mx.array(rejected_arr), mx.array(chosen_masks), mx.array(rejected_masks)

        if not train:
            break


def evaluate_dpo(
    model,
    ref_model,
    dataset,
    batch_size,
    num_batches,
    beta: float,
    delta: float,
    max_seq_length,
    loss_type,
    loss: callable = dpo_loss,
):
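    """Run evaluation over `num_batches` batches (all batches if -1) and return
    the token-weighted average loss, rewards, token count, and metrics."""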
    all_losses = 0
    all_rewards = mx.zeros((2,))
    all_metrics = None
    ntokens = 0

    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_dpo_batches(
            dataset=dataset,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        chosen, rejected, chosen_masks, rejected_masks = batch
        # Unpack into `lvalue` so the `loss` callable is not shadowed on later batches
        lvalue, reward, toks, metrics = loss(
            model=model,
            ref_model=ref_model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            loss_type=loss_type,
            beta=beta,
            delta=delta,
        )
        all_losses += lvalue * toks
        all_rewards += reward
        ntokens += toks

        if all_metrics is None:
            all_metrics = {k: v * toks for k, v in metrics.items()}
        else:
            for k, v in metrics.items():
                all_metrics[k] += v * toks

        mx.eval(all_losses, all_rewards, ntokens)

    all_losses = mx.distributed.all_sum(all_losses)
    all_rewards = mx.distributed.all_sum(all_rewards)
    ntokens = mx.distributed.all_sum(ntokens)
    all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}

    avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
    avg_rewards = (all_rewards / ntokens).tolist()
    avg_loss = (all_losses / ntokens).item()

    return avg_loss, avg_rewards, ntokens, avg_metrics


def train_dpo(
    model,
    ref_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: DPOTrainingArgs = DPOTrainingArgs(),
    loss: callable = dpo_loss,
    training_callback: TrainingCallback = None,
    loss_type="sigmoid",
):
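    """Run the DPO training loop: periodically evaluates on val_dataset, reports
    training metrics every `steps_per_report` iterations, and saves adapter
    weights every `steps_per_save` iterations and at the end of training."""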
print(f"Starting DPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
def step(batch):
chosen, rejected, chosen_masks, rejected_masks = batch
2025-02-01 04:27:59 +08:00
(lvalue, reward, toks, metrics), grad = loss_value_and_grad(
2025-01-19 07:19:36 +08:00
model,
2025-01-31 07:01:43 +08:00
ref_model,
2025-01-19 07:19:36 +08:00
chosen,
rejected,
chosen_masks,
rejected_masks
)
grad = average_gradients(grad)
optimizer.update(model, grad)
2025-02-01 04:27:59 +08:00
return lvalue, reward, toks, metrics
2025-01-19 07:19:36 +08:00
def loss_wrapper(model, ref_model, chosen, rejected, chosen_masks, rejected_masks):
2025-02-01 04:27:59 +08:00
return loss(
2025-01-19 07:19:36 +08:00
model=model,
2025-02-01 23:08:52 +08:00
ref_model=ref_model,
2025-01-19 07:19:36 +08:00
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
beta=args.beta,
delta=args.delta,
2025-01-31 07:01:43 +08:00
loss_type=loss_type
2025-01-19 07:19:36 +08:00
)
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
trained_tokens = 0
2025-01-26 22:09:55 +08:00
accumulated_metrics = {
'accuracies': 0,
'margins': 0,
'policy_rejected_logps': 0,
'policy_chosen_logps': 0,
'rejected_logits_mean': 0,
'chosen_logits_mean': 0
}
2025-01-19 07:19:36 +08:00
    start = time.perf_counter()
    for it, batch in zip(
        range(1, args.iters + 1),
        iterate_dpo_batches(
            dataset=train_dataset,
            batch_size=args.batch_size,
            max_seq_length=args.max_seq_length,
            train=True,
        ),
    ):
        if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
            stop = time.perf_counter()
            val_loss, val_rewards, val_ntokens, val_metrics = evaluate_dpo(
                model=model,
                ref_model=ref_model,
                dataset=val_dataset,
                batch_size=args.batch_size,
                num_batches=args.val_batches,
                max_seq_length=args.max_seq_length,
                loss=loss,
                beta=args.beta,
                delta=args.delta,
                loss_type=loss_type,
            )
            val_time = time.perf_counter() - stop
            if rank == 0:
                print(
                    f"Iter {it}: "
                    f"Val loss {val_loss:.3f}, "
                    f"Val chosen reward {val_rewards[0]:.3f}, "
                    f"Val rejected reward {val_rewards[1]:.3f}, "
                    f"Val accuracy {val_metrics['accuracies']:.3f}, "
                    f"Val margin {val_metrics['margins']:.3f}, "
                    f"Val took {val_time:.3f}s",
                    flush=True,
                )

            if training_callback is not None:
                training_callback.on_val_loss_report({
                    "iteration": it,
                    "val_loss": val_loss,
                    "val_chosen_reward": val_rewards[0],
                    "val_rejected_reward": val_rewards[1],
                    **{f"val_{k}": v for k, v in val_metrics.items()},
                    "val_time": val_time,
                })

            start = time.perf_counter()

        lvalue, reward, toks, metrics = step(batch)
        losses += lvalue
        rewards += reward
        n_tokens += toks
        steps += 1

        for k, v in metrics.items():
            accumulated_metrics[k] += v

        mx.eval(state, losses, rewards, n_tokens)

        if it % args.steps_per_report == 0 or it == args.iters:
            stop = time.perf_counter()

            train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
            train_rewards = mx.distributed.all_sum(rewards).tolist()
            train_rewards = [r / (steps * world_size) for r in train_rewards]
            avg_metrics = {k: v / (steps * world_size) for k, v in accumulated_metrics.items()}
            n_tokens = mx.distributed.all_sum(n_tokens).item()
            learning_rate = optimizer.learning_rate.item()
            it_sec = args.steps_per_report / (stop - start)
            tokens_sec = float(n_tokens) / (stop - start)
            trained_tokens += n_tokens
            peak_mem = mx.metal.get_peak_memory() / 1e9

            if rank == 0:
                print(
                    f"Iter {it}: Train loss {train_loss:.3f}, "
                    f"Chosen reward {train_rewards[0]:.3f}, "
                    f"Rejected reward {train_rewards[1]:.3f}, "
                    f"Accuracy {avg_metrics['accuracies']:.3f}, "
                    f"Margin {avg_metrics['margins']:.3f}, "
                    f"Learning Rate {learning_rate:.3e}, "
                    f"It/sec {it_sec:.3f}, "
                    f"Tokens/sec {tokens_sec:.3f}, "
                    f"Trained Tokens {trained_tokens}, "
                    f"Peak mem {peak_mem:.3f} GB",
                    flush=True,
                )

            if training_callback is not None:
                train_info = {
                    "iteration": it,
                    "train_loss": train_loss,
                    "train_chosen_reward": train_rewards[0],
                    "train_rejected_reward": train_rewards[1],
                    **{f"train_{k}": v for k, v in avg_metrics.items()},
                    "learning_rate": learning_rate,
                    "iterations_per_second": it_sec,
                    "tokens_per_second": tokens_sec,
                    "trained_tokens": trained_tokens,
                    "peak_memory": peak_mem,
                }
                training_callback.on_train_loss_report(train_info)

            losses = 0
            rewards = mx.zeros((2,))
            n_tokens = 0
            steps = 0
            # Also reset the metric accumulators so each report covers only its own window
            accumulated_metrics = {k: 0 for k in accumulated_metrics}
            start = time.perf_counter()

        # Save adapter weights
        if it % args.steps_per_save == 0:
            adapter_weights = dict(tree_flatten(model.trainable_parameters()))
            mx.save_safetensors(str(args.adapter_file), adapter_weights)
            checkpoint = (
                Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
            )
            mx.save_safetensors(str(checkpoint), adapter_weights)
            print(
                f"Iter {it}: Saved adapter weights to "
                f"{args.adapter_file} and {checkpoint}."
            )

    # Save final weights
    adapter_weights = dict(tree_flatten(model.trainable_parameters()))
    mx.save_safetensors(str(args.adapter_file), adapter_weights)
    print(f"Saved final weights to {args.adapter_file}.")