mirror of
https://github.com/ml-explore/mlx-examples.git
finish
@@ -12,7 +12,6 @@ import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from ..generate import generate
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs

@@ -1,54 +1,48 @@
# Copyright © 2024 Apple Inc.

import time
from pathlib import Path
from dataclasses import dataclass, field

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten

from .dpo_trainer import DPOTrainingArgs, grad_checkpoint


@dataclass
class ORPOTrainingArgs(DPOTrainingArgs):
    """
    Training arguments specific to ORPO, extending the DPO arguments.
    """
    mu: float = field(
        default=0.5,
        metadata={"help": "ORPO KL divergence weight parameter."},
    )
    reward_scaling: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for offline rewards."},
    )

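# Illustrative sketch (editor's addition, not part of the commit): ORPOTrainingArgs
# inherits every field of DPOTrainingArgs (batch size, iteration count, beta, and so on;
# the exact names live in dpo_trainer) and adds the two ORPO-specific knobs above.
# A hypothetical construction, assuming the inherited fields all have defaults:
#
#     args = ORPOTrainingArgs(mu=0.3, reward_scaling=2.0)
#     print(args.mu, args.reward_scaling)  # -> 0.3 2.0
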
def orpo_loss(
    model,
    reference_teacher_model,
    chosen: mx.array,
    rejected: mx.array,
    chosen_masks: mx.array,
    rejected_masks: mx.array,
    chosen_rewards: mx.array,  # Pre-computed rewards
    rejected_rewards: mx.array,  # Pre-computed rewards
    beta: float,
    delta: float,
    mu: float = 0.5,  # ORPO hyperparameter for balancing KL divergence
    loss_type: str = "sigmoid",
    is_reference_free: bool = False,
    reward_scaling: float = 1.0,
):
"""
|
||||
Calculate ORPO loss for inputs.
|
||||
ORPO extends DPO by adding a KL regularization term to prevent overfitting to preferences.
|
||||
|
||||
Calculate ORPO loss using pre-computed rewards.
|
||||
Args:
|
||||
model: Policy model
|
||||
reference_teacher_model: Reference model
|
||||
chosen: Chosen sequence tokens
|
||||
rejected: Rejected sequence tokens
|
||||
chosen_masks: Masks for chosen sequences
|
||||
rejected_masks: Masks for rejected sequences
|
||||
chosen_masks: Attention masks for chosen sequences
|
||||
rejected_masks: Attention masks for rejected sequences
|
||||
chosen_rewards: Pre-computed rewards for chosen sequences
|
||||
rejected_rewards: Pre-computed rewards for rejected sequences
|
||||
beta: Temperature parameter
|
||||
delta: Margin for DPOP loss type
|
||||
mu: ORPO hyperparameter for balancing KL divergence (default: 0.5)
|
||||
loss_type: Loss type ('sigmoid', 'hinge', 'ipo', or 'dpop')
|
||||
is_reference_free: Whether to use reference-free training
|
||||
reward_scaling: Scaling factor for rewards
|
||||
Returns:
|
||||
Tuple of (loss, reward, num_tokens)
|
||||
Loss value, rewards, and number of tokens.
|
||||
"""
|
||||
    def make_predictions(model, x, mask):
        inputs = x[:, :-1]
        targets = x[:, 1:]
        logits = model(inputs)
        logits = logits.astype(mx.float32)
@@ -59,218 +53,335 @@ def orpo_loss(
        return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]

    # Per-token log probabilities for the policy model
    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)

    # Scale the pre-computed rewards
    chosen_rewards = chosen_rewards * reward_scaling
    rejected_rewards = rejected_rewards * reward_scaling

    # ORPO uses the reward difference directly
    reward_diff = chosen_rewards - rejected_rewards

    # Calculate the ORPO loss: a logistic objective on the policy
    # log-probability difference, weighted by the reward difference
    policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
    loss = -nn.log_sigmoid(beta * (policy_diff * reward_diff))
    loss = mx.mean(loss)

    # Calculate number of tokens for logging
    num_tokens = (chosen_masks.sum() + rejected_masks.sum())

    # Calculate rewards for logging
    avg_chosen_reward = mx.mean(chosen_rewards)
    avg_rejected_reward = mx.mean(rejected_rewards)
    reward = mx.stack([avg_chosen_reward, avg_rejected_reward])

    return loss, reward, num_tokens

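# Illustrative sketch (editor's addition, not part of the commit): a minimal, self-contained
# check of the reward-weighted loss above. DummyLM and all shapes/values here are
# hypothetical and exist only to show the expected call signature and return types.
def _orpo_loss_example():
    class DummyLM(nn.Module):
        def __init__(self, vocab_size=8, dims=4):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, dims)
            self.proj = nn.Linear(dims, vocab_size)

        def __call__(self, tokens):
            # (batch, seq) token ids -> (batch, seq, vocab) logits
            return self.proj(self.embed(tokens))

    model = DummyLM()
    chosen = mx.array([[1, 2, 3, 4]])
    rejected = mx.array([[1, 2, 5, 6]])
    masks = mx.ones(chosen.shape)

    # With the binary reward pair (1, 0) the loss reduces to
    # -log_sigmoid(beta * (logp_chosen - logp_rejected)).
    loss, reward, num_tokens = orpo_loss(
        model=model,
        reference_teacher_model=model,  # unused by the reward-based loss
        chosen=chosen,
        rejected=rejected,
        chosen_masks=masks,
        rejected_masks=masks,
        chosen_rewards=mx.array([1.0]),
        rejected_rewards=mx.array([0.0]),
        beta=0.1,
        delta=50.0,  # unused by the reward-based loss
    )
    return loss.item(), reward.tolist(), num_tokens.item()
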
def evaluate_orpo(
    model,
    reference_model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    delta: float,
    mu: float = 0.5,
    reward_scaling: float = 1.0,
    max_seq_length=2048,
    loss_type="sigmoid",
    is_reference_free=False,
):
"""
|
||||
Evaluate model using ORPO metrics.
|
||||
|
||||
Args:
|
||||
model: Policy model to evaluate
|
||||
reference_model: Reference model for comparison
|
||||
dataset: Evaluation dataset
|
||||
tokenizer: Tokenizer for processing text
|
||||
batch_size: Batch size for evaluation
|
||||
num_batches: Number of batches to evaluate (-1 for full dataset)
|
||||
beta: Temperature parameter
|
||||
delta: Margin for DPOP loss
|
||||
mu: ORPO KL divergence weight
|
||||
max_seq_length: Maximum sequence length
|
||||
loss_type: Type of loss function
|
||||
is_reference_free: Whether to use reference-free evaluation
|
||||
|
||||
Returns:
|
||||
Tuple of (loss, rewards, kl_metrics), where:
|
||||
- loss is the total ORPO loss
|
||||
- rewards is [chosen_reward, rejected_reward]
|
||||
- kl_metrics is [chosen_kl, rejected_kl]
|
||||
Evaluation function for ORPO.
|
||||
"""
|
||||
    all_losses = 0
    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
    ntokens = 0

    index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)

    for _, batch in zip(
        index_iterator,
        iterate_orpo_batches(
            dataset=dataset,
            tokenizer=tokenizer,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
        ),
    ):
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch

        loss, reward, toks = orpo_loss(
            model=model,
            reference_teacher_model=reference_model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=beta,
            delta=delta,
            mu=mu,
            loss_type=loss_type,
            is_reference_free=is_reference_free,
            reward_scaling=reward_scaling,
        )

        all_losses += loss * toks
        all_rewards += reward
        ntokens += toks
        mx.eval(all_losses, all_rewards, ntokens)

    # Aggregate metrics across distributed workers if necessary
    all_losses = mx.distributed.all_sum(all_losses)
    all_rewards = mx.distributed.all_sum(all_rewards)
    ntokens = mx.distributed.all_sum(ntokens)

    return (all_losses / ntokens).item(), all_rewards.tolist()

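# Editor's note (sketch, not part of the commit): `iter(int, 1)` above is the standard
# two-argument iter() idiom for an endless iterator — `int()` returns 0, which never equals
# the sentinel 1 — so passing num_batches=-1 evaluates until iterate_orpo_batches is
# exhausted (one pass over the dataset, since it is constructed with train=False).
#
#     from itertools import islice
#     assert list(islice(iter(int, 1), 3)) == [0, 0, 0]
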
def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
    """
    Modified batch iterator for ORPO that includes pre-computed rewards.
    Works with pre-tokenized input data.
    """
    # Sort pairs by length of the chosen response
    idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))

    if len(dataset) < batch_size:
        raise ValueError(
            f"Dataset must have at least batch_size={batch_size}"
            f" examples but only has {len(dataset)}."
        )

    step = mx.distributed.init().size()
    if batch_size % step != 0:
        raise ValueError("The batch size must be divisible by the number of workers")

    batch_idx = [
        idx[i : i + batch_size : step]
        for i in range(0, len(idx) - batch_size + 1, batch_size)
    ]

    while True:
        indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
        for i in indices:
            batch = [dataset[j] for j in batch_idx[i]]

            # Get lengths assuming data is already tokenized
            chosen_lengths = [len(x['chosen']) for x in batch]
            rejected_lengths = [len(x['rejected']) for x in batch]
            max_length = max(max(chosen_lengths), max(rejected_lengths))

            if max_length > max_seq_length:
                print(
                    f"[WARNING] Sequences longer than {max_seq_length} tokens "
                    f"will be truncated."
                )

            pad_to = 8
            max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
            max_length_in_batch = min(max_length_in_batch, max_seq_length)

            chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
            chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)

            # Always use binary rewards
            chosen_rewards = np.ones((batch_size // step,), np.float32)
            rejected_rewards = np.zeros((batch_size // step,), np.float32)

            for j in range(batch_size // step):
                # Use pre-tokenized sequences directly
                chosen_length = min(chosen_lengths[j], max_seq_length)
                chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
                chosen_masks[j, :chosen_length] = 1.0

                rejected_length = min(rejected_lengths[j], max_seq_length)
                rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
                rejected_masks[j, :rejected_length] = 1.0

            yield (mx.array(chosen_arr), mx.array(rejected_arr),
                   mx.array(chosen_masks), mx.array(rejected_masks),
                   mx.array(chosen_rewards), mx.array(rejected_rewards))

        if not train:
            break

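# Illustrative sketch (editor's addition, not part of the commit): the iterator above expects
# pre-tokenized examples, i.e. dicts with integer token-id lists under 'chosen' and 'rejected'.
# The toy dataset below is hypothetical and only demonstrates the yielded shapes; lengths are
# padded up to a multiple of 8, and rewards are the fixed binary pair (1.0, 0.0).
def _iterate_orpo_batches_example():
    toy_dataset = [
        {"chosen": [1, 2, 3, 4], "rejected": [1, 2, 9]},
        {"chosen": [5, 6], "rejected": [5, 7, 8, 8, 8]},
    ]
    batch = next(
        iterate_orpo_batches(
            dataset=toy_dataset,
            tokenizer=None,  # unused for pre-tokenized data
            batch_size=2,
            max_seq_length=16,
        )
    )
    chosen, rejected, chosen_masks, rejected_masks, chosen_rw, rejected_rw = batch
    # chosen.shape == rejected.shape == (2, 8); chosen_rw == [1.0, 1.0]; rejected_rw == [0.0, 0.0]
    return chosen.shape, rejected.shape
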
def train_orpo(
    model,
    reference_model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: ORPOTrainingArgs = ORPOTrainingArgs(),
    training_callback=None,
):
    """
    Train a model using ORPO (Offline Rejection Preference Optimization).
    The loop mirrors the DPO trainer but uses the ORPO loss and reward-carrying batches.
    """
print(f"Starting ORPO training..., iters: {args.iters}")
|
||||
world = mx.distributed.init()
|
||||
world_size = world.size()
|
||||
rank = world.rank()
|
||||
|
||||
if world_size > 1:
|
||||
print(f"Node {rank} of {world_size}")
|
||||
|
||||
if args.grad_checkpoint:
|
||||
grad_checkpoint(model.layers[0])
|
||||
|
||||
state = [model.state, optimizer.state]
|
||||
|
||||
def step(batch):
|
||||
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
|
||||
|
||||
(loss, reward, toks), grad = loss_value_and_grad(
|
||||
model,
|
||||
chosen,
|
||||
rejected,
|
||||
chosen_masks,
|
||||
rejected_masks,
|
||||
chosen_rewards,
|
||||
rejected_rewards
|
||||
)
|
||||
|
||||
grad = average_gradients(grad)
|
||||
optimizer.update(model, grad)
|
||||
|
||||
return loss, reward, toks
|
||||
|
||||
def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
|
||||
chosen_rewards, rejected_rewards):
|
||||
return orpo_loss(
|
||||
model=model,
|
||||
chosen=chosen,
|
||||
rejected=rejected,
|
||||
chosen_masks=chosen_masks,
|
||||
rejected_masks=rejected_masks,
|
||||
chosen_rewards=chosen_rewards,
|
||||
rejected_rewards=rejected_rewards,
|
||||
beta=args.beta,
|
||||
reward_scaling=args.reward_scaling
|
||||
)
|
||||
|
||||
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
|
||||
|
||||
# Training loop with progress tracking
|
||||
losses = 0
|
||||
rewards = mx.zeros((2,))
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
trained_tokens = 0
|
||||
|
||||
start = time.perf_counter()
|
||||
for it, batch in zip(
|
||||
range(1, args.iters + 1),
|
||||
iterate_orpo_batches( # reuse DPO batch iterator
|
||||
dataset=train_dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
max_seq_length=args.max_seq_length,
|
||||
train=True,
|
||||
),
|
||||
):
|
||||
# Evaluate if needed
|
||||
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
val_loss, val_rewards = evaluate_orpo(
|
||||
model=model,
|
||||
dataset=val_dataset,
|
||||
tokenizer=tokenizer,
|
||||
batch_size=args.batch_size,
|
||||
num_batches=args.val_batches,
|
||||
max_seq_length=args.max_seq_length,
|
||||
beta=args.beta,
|
||||
reward_scaling=args.reward_scaling,
|
||||
)
|
||||
val_time = time.perf_counter() - stop
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: "
|
||||
f"Val loss {val_loss:.3f}, "
|
||||
f"Val chosen reward {val_rewards[0]:.3f}, "
|
||||
f"Val rejected reward {val_rewards[1]:.3f}, "
|
||||
f"Val took {val_time:.3f}s",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_val_loss_report({
|
||||
"iteration": it,
|
||||
"val_loss": val_loss,
|
||||
"val_chosen_reward": val_rewards[0],
|
||||
"val_rejected_reward": val_rewards[1],
|
||||
"val_time": val_time,
|
||||
})
|
||||
|
||||
start = time.perf_counter()
|
||||
|
||||
# Training step
|
||||
loss, reward, toks = step(batch)
|
||||
losses += loss
|
||||
rewards += reward
|
||||
n_tokens += toks
|
||||
steps += 1
|
||||
mx.eval(state, losses, rewards, n_tokens)
|
||||
|
||||
# Report training metrics if needed
|
||||
if it % args.steps_per_report == 0 or it == args.iters:
|
||||
stop = time.perf_counter()
|
||||
|
||||
train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
|
||||
train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
|
||||
n_tokens = mx.distributed.all_sum(n_tokens).item()
|
||||
learning_rate = optimizer.learning_rate.item()
|
||||
it_sec = args.steps_per_report / (stop - start)
|
||||
tokens_sec = float(n_tokens) / (stop - start)
|
||||
trained_tokens += n_tokens
|
||||
peak_mem = mx.metal.get_peak_memory() / 1e9
|
||||
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Iter {it}: Train loss {train_loss:.3f}, "
|
||||
f"Chosen reward {train_rewards[0]:.3f}, "
|
||||
f"Rejected reward {train_rewards[1]:.3f}, "
|
||||
f"Learning Rate {learning_rate:.3e}, "
|
||||
f"It/sec {it_sec:.3f}, "
|
||||
f"Tokens/sec {tokens_sec:.3f}, "
|
||||
f"Trained Tokens {trained_tokens}, "
|
||||
f"Peak mem {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if training_callback is not None:
|
||||
training_callback.on_train_loss_report({
|
||||
"iteration": it,
|
||||
"train_loss": train_loss,
|
||||
"train_chosen_reward": train_rewards[0],
|
||||
"train_rejected_reward": train_rewards[1],
|
||||
"learning_rate": learning_rate,
|
||||
"iterations_per_second": it_sec,
|
||||
"tokens_per_second": tokens_sec,
|
||||
"trained_tokens": trained_tokens,
|
||||
"peak_memory": peak_mem,
|
||||
})
|
||||
|
||||
losses = 0
|
||||
rewards = mx.zeros((2,))
|
||||
n_tokens = 0
|
||||
steps = 0
|
||||
start = time.perf_counter()
|
||||
|
||||
# Save model weights if needed
|
||||
if it % args.steps_per_save == 0:
|
||||
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||
checkpoint = (
|
||||
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
|
||||
)
|
||||
mx.save_safetensors(str(checkpoint), adapter_weights)
|
||||
print(
|
||||
f"Iter {it}: Saved adapter weights to "
|
||||
f"{args.adapter_file} and {checkpoint}."
|
||||
)
|
||||
|
||||
# Save final weights
|
||||
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
|
||||
mx.save_safetensors(str(args.adapter_file), adapter_weights)
|
||||
print(f"Saved final weights to {args.adapter_file}.")
|
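
# Illustrative usage sketch (editor's addition, not part of the commit): wiring the pieces
# above together. The `mlx_lm.load` helper, the model path, the inherited args fields used
# below, and the shape of the datasets (lists of dicts with pre-tokenized 'chosen'/'rejected'
# fields) are assumptions about the surrounding mlx-examples LoRA tooling, not something
# this commit defines.
#
#     from mlx_lm import load
#     import mlx.optimizers as optim
#
#     model, tokenizer = load("path/to/model")        # hypothetical path
#     reference_model, _ = load("path/to/model")
#     optimizer = optim.Adam(learning_rate=1e-5)
#     args = ORPOTrainingArgs(batch_size=2, iters=200, reward_scaling=1.0)
#     train_orpo(model, reference_model, tokenizer, optimizer,
#                train_dataset, val_dataset, args=args)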