Goekdeniz-Guelmez 2025-01-19 01:58:29 +01:00
parent 7d279b51ef
commit fa80d081f2
4 changed files with 372 additions and 188 deletions

View File

@ -20,6 +20,7 @@ LoRA (QLoRA).[^qlora] LoRA fine-tuning works with the following model families:
- [Run](#Run)
- [Fine-tune](#Fine-tune)
- [DPO Training](#DPO-Training)
- [ORPO Training](#ORPO-Training)
- [Evaluate](#Evaluate)
- [Generate](#Generate)
- [Fuse](#Fuse)
@ -105,6 +106,38 @@ For DPO training, the data should be in JSONL format with the following structur
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```
### ORPO Training
Offline Reward Policy Optimization (ORPO) fine-tunes a model on human preference data with pre-computed (offline) rewards. To use ORPO training, set the training mode to `orpo`:
```shell
mlx_lm.lora \
--model <path_to_model> \
--train \
--training-mode orpo \
--data <path_to_data> \
--beta 0.1 \
--reward-scaling 1.0
```
ORPO training accepts the following additional parameters:
- `--beta`: Temperature parameter for the logistic (sigmoid) loss (default: `0.1`)
- `--reward-scaling`: Scaling factor for the offline rewards (default: `1.0`)
For ORPO training, the data should be in JSONL format with the following structure:
```jsonl
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response"}
```
The training process automatically assigns binary rewards (1.0 for the chosen and 0.0 for the rejected response) if no explicit rewards are provided. You can also provide custom rewards in your data:
```jsonl
{"prompt": "User prompt", "chosen": "Preferred response", "rejected": "Less preferred response", "chosen_reward": 0.8, "rejected_reward": 0.3}
```
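For intuition on how `--beta` and `--reward-scaling` combine with these rewards, here is a minimal sketch (plain Python, not part of the package; `orpo_pair_loss` and `policy_diff` are illustrative names) of the per-pair objective the ORPO trainer uses. `policy_diff` stands for the difference in summed log-probabilities the model assigns to the chosen versus the rejected response:
```python
import math

def orpo_pair_loss(policy_diff, chosen_reward=1.0, rejected_reward=0.0,
                   beta=0.1, reward_scaling=1.0):
    # Scale the offline rewards, weight the log-probability difference by the
    # reward difference, and apply a logistic (log-sigmoid) loss.
    reward_diff = reward_scaling * (chosen_reward - rejected_reward)
    return -math.log(1.0 / (1.0 + math.exp(-beta * policy_diff * reward_diff)))

# With the default binary rewards, a model that already prefers the chosen
# response (positive policy_diff) gets a small loss, and vice versa.
print(orpo_pair_loss(policy_diff=50.0))   # ~0.007
print(orpo_pair_loss(policy_diff=-50.0))  # ~5.007
```
Larger `--reward-scaling` values widen the effective reward gap and therefore sharpen the loss for a given `policy_diff`.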
### Evaluate
To compute test set perplexity use:

View File

@ -16,6 +16,7 @@ from .tokenizer_utils import TokenizerWrapper
from .tuner.datasets import load_dataset
from .tuner.trainer import TrainingArgs, TrainingCallback, evaluate, train
from .tuner.dpo_trainer import DPOTrainingArgs, evaluate_dpo, train_dpo
from .tuner.orpo_trainer import ORPOTrainingArgs, evaluate_orpo, train_orpo
from .tuner.utils import (
build_schedule,
linear_to_lora_layers,
@ -70,6 +71,7 @@ CONFIG_DEFAULTS = {
"delta": 50.0,
"reference_model_path": None,
"train_bias_only": False,
"reward_scaling": 1.0,
}
@ -106,7 +108,7 @@ def build_parser():
"--training-mode",
type=str,
choices=["normal", "dpo", "orpo"],
help="Training mode: normal, DPO or ORPO",
help="Training mode: normal, DPO or ORPO.",
)
parser.add_argument(
"--num-layers",
@ -149,7 +151,7 @@ def build_parser():
parser.add_argument(
"--test",
action="store_true",
help="Evaluate on the test set after training",
help="Evaluate on the test set after training.",
default=None,
)
parser.add_argument(
@ -166,7 +168,7 @@ def build_parser():
"-c",
"--config",
type=str,
help="A YAML configuration file with the training options",
help="A YAML configuration file with the training options.",
)
parser.add_argument(
"--grad-checkpoint",
@ -180,7 +182,8 @@ def build_parser():
parser.add_argument("--delta", type=float)
parser.add_argument("--reference-model-path", type=str)
parser.add_argument("--train-bias-only", action="store_true")
parser.add_argument("--seed", type=int, help="The PRNG seed")
parser.add_argument("--reward-scaling", type=float, help="Scaling factor for offline rewards.")
parser.add_argument("--seed", type=int, help="The PRNG seed.")
return parser
@ -226,7 +229,8 @@ def train_model(
build_schedule(args.lr_schedule) if args.lr_schedule else args.learning_rate
)
)
    # Train model based on training mode
if args.training_mode == "dpo":
training_args = DPOTrainingArgs(
batch_size=args.batch_size,
@ -261,6 +265,32 @@ def train_model(
args=training_args,
training_callback=training_callback,
)
elif args.training_mode == "orpo":
training_args = ORPOTrainingArgs(
batch_size=args.batch_size,
iters=args.iters,
val_batches=args.val_batches,
steps_per_report=args.steps_per_report,
steps_per_eval=args.steps_per_eval,
steps_per_save=args.save_every,
adapter_file=adapter_file,
max_seq_length=args.max_seq_length,
grad_checkpoint=args.grad_checkpoint,
beta=args.beta,
reward_scaling=args.reward_scaling,
train_bias_only=args.train_bias_only,
seed=args.seed,
)
train_orpo(
model=model,
tokenizer=tokenizer,
optimizer=opt,
train_dataset=train_set,
val_dataset=valid_set,
args=training_args,
training_callback=training_callback,
)
else:
training_args = TrainingArgs(
batch_size=args.batch_size,
@ -304,7 +334,19 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
max_seq_length=args.max_seq_length,
beta=args.beta,
delta=args.delta,
        loss_type=args.dpo_loss_type,
)
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
elif args.training_mode == "orpo":
test_loss, test_rewards = evaluate_orpo(
model=model,
dataset=test_set,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.test_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
reward_scaling=args.reward_scaling,
)
print(f"Test loss {test_loss:.3f}, Rewards: {test_rewards[0]:.3f}, {test_rewards[1]:.3f}")
else:
@ -318,7 +360,6 @@ def evaluate_model(args, model: nn.Module, tokenizer: TokenizerWrapper, test_set
)
test_ppl = math.exp(test_loss)
print(f"Test loss {test_loss:.3f}, Test ppl {test_ppl:.3f}.")

View File

@ -12,7 +12,6 @@ import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten
from ..generate import generate
from .trainer import TrainingCallback, grad_checkpoint, TrainingArgs

View File

@ -1,54 +1,48 @@
# Copyright © 2024 Apple Inc.
import time
from dataclasses import dataclass, field
from pathlib import Path

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten

from .dpo_trainer import DPOTrainingArgs, grad_checkpoint
@dataclass
class ORPOTrainingArgs(DPOTrainingArgs):
"""
Training arguments specific to ORPO, extending DPO arguments.
"""
    reward_scaling: float = field(
        default=1.0,
        metadata={"help": "Scaling factor for offline rewards."},
    )
def orpo_loss(
    model,
    chosen: mx.array,
    rejected: mx.array,
    chosen_masks: mx.array,
    rejected_masks: mx.array,
    chosen_rewards: mx.array,
    rejected_rewards: mx.array,
    beta: float,
    reward_scaling: float = 1.0,
):
    """
    Calculate ORPO loss using pre-computed rewards.
    Args:
        model: Policy model
        chosen: Chosen sequence tokens
        rejected: Rejected sequence tokens
        chosen_masks: Attention masks for chosen sequences
        rejected_masks: Attention masks for rejected sequences
        chosen_rewards: Pre-computed rewards for chosen sequences
        rejected_rewards: Pre-computed rewards for rejected sequences
        beta: Temperature parameter
        reward_scaling: Scaling factor for rewards
    Returns:
        Loss value, rewards, and number of tokens.
    """
def make_predictions(model, x, mask):
inputs = x[:, :-1]
@ -59,218 +53,335 @@ def orpo_loss(
return -nn.losses.cross_entropy(logits, targets) * mask[:, :-1]
    # Per-token log probabilities for the policy model
    policy_chosen_scores = make_predictions(model, chosen, chosen_masks)
    policy_rejected_scores = make_predictions(model, rejected, rejected_masks)
    # Scale the pre-computed rewards
    chosen_rewards = chosen_rewards * reward_scaling
    rejected_rewards = rejected_rewards * reward_scaling
    # ORPO uses the reward difference directly
    reward_diff = chosen_rewards - rejected_rewards
    # Calculate the ORPO loss with a logistic (log-sigmoid) function over the
    # reward-weighted difference in sequence log probabilities
    policy_diff = policy_chosen_scores.sum(-1) - policy_rejected_scores.sum(-1)
    loss = mx.mean(-nn.log_sigmoid(beta * (policy_diff * reward_diff)))
    # Number of tokens for logging
    num_tokens = chosen_masks.sum() + rejected_masks.sum()
    # Average rewards for logging
    avg_chosen_reward = mx.mean(chosen_rewards)
    avg_rejected_reward = mx.mean(rejected_rewards)
    reward = mx.stack([avg_chosen_reward, avg_rejected_reward])
return loss, reward, num_tokens
def evaluate_orpo(
    model,
    dataset,
    tokenizer,
    batch_size,
    num_batches,
    beta: float,
    reward_scaling: float = 1.0,
    max_seq_length=2048,
):
    """
    Evaluation function for ORPO.
    Args:
        model: Policy model to evaluate
        dataset: Evaluation dataset
        tokenizer: Tokenizer for processing text
        batch_size: Batch size for evaluation
        num_batches: Number of batches to evaluate (-1 for the full dataset)
        beta: Temperature parameter
        reward_scaling: Scaling factor for offline rewards
        max_seq_length: Maximum sequence length
    Returns:
        Tuple of (loss, [chosen_reward, rejected_reward]).
    """
    all_losses = 0
    all_rewards = mx.zeros((2,))  # [chosen_reward, rejected_reward]
    ntokens = 0
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
        iterate_orpo_batches(
dataset=dataset,
tokenizer=tokenizer,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
        # Compute ORPO loss and rewards for this batch
        chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
        loss, reward, toks = orpo_loss(
            model=model,
            chosen=chosen,
            rejected=rejected,
            chosen_masks=chosen_masks,
            rejected_masks=rejected_masks,
            chosen_rewards=chosen_rewards,
            rejected_rewards=rejected_rewards,
            beta=beta,
            reward_scaling=reward_scaling,
        )
all_losses += loss * toks
all_rewards += reward
ntokens += toks
        mx.eval(all_losses, all_rewards, ntokens)
    # Aggregate metrics across distributed workers
    all_losses = mx.distributed.all_sum(all_losses)
    all_rewards = mx.distributed.all_sum(all_rewards)
    ntokens = mx.distributed.all_sum(ntokens)
    return (all_losses / ntokens).item(), all_rewards.tolist()
def iterate_orpo_batches(dataset, tokenizer, batch_size, max_seq_length, train=False):
"""
Modified batch iterator for ORPO that includes pre-computed rewards.
Works with pre-tokenized input data.
"""
# Sort pairs by length of the chosen response
idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx]['chosen']))
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size}"
f" examples but only has {len(dataset)}."
)
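    # In distributed training, each worker processes batch_size // num_workers
    # examples per batch, so the batch size must divide evenly across workers.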
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
batch_idx = [
idx[i : i + batch_size : step]
for i in range(0, len(idx) - batch_size + 1, batch_size)
]
while True:
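        # Shuffle the batch order on every pass during training; keep a fixed order for evaluation.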
indices = np.random.permutation(len(batch_idx)) if train else range(len(batch_idx))
for i in indices:
batch = [dataset[j] for j in batch_idx[i]]
# Get lengths assuming data is already tokenized
chosen_lengths = [len(x['chosen']) for x in batch]
rejected_lengths = [len(x['rejected']) for x in batch]
max_length = max(max(chosen_lengths), max(rejected_lengths))
if max_length > max_seq_length:
print(
f"[WARNING] Sequences longer than {max_seq_length} tokens "
f"will be truncated."
)
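            # Pad the batch to a multiple of 8 tokens, capped at max_seq_length.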
pad_to = 8
max_length_in_batch = pad_to * ((max_length + pad_to - 1) // pad_to)
max_length_in_batch = min(max_length_in_batch, max_seq_length)
chosen_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
rejected_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32)
chosen_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
rejected_masks = np.zeros((batch_size // step, max_length_in_batch), np.float32)
            # Default binary rewards (1.0 for chosen, 0.0 for rejected); replace these
            # if your dataset provides explicit chosen_reward / rejected_reward values.
chosen_rewards = np.ones((batch_size // step,), np.float32)
rejected_rewards = np.zeros((batch_size // step,), np.float32)
for j in range(batch_size // step):
# Use pre-tokenized sequences directly
chosen_length = min(chosen_lengths[j], max_seq_length)
chosen_arr[j, :chosen_length] = batch[j]['chosen'][:chosen_length]
chosen_masks[j, :chosen_length] = 1.0
rejected_length = min(rejected_lengths[j], max_seq_length)
rejected_arr[j, :rejected_length] = batch[j]['rejected'][:rejected_length]
rejected_masks[j, :rejected_length] = 1.0
yield (mx.array(chosen_arr), mx.array(rejected_arr),
mx.array(chosen_masks), mx.array(rejected_masks),
mx.array(chosen_rewards), mx.array(rejected_rewards))
if not train:
break
def train_orpo(
    model,
    tokenizer,
    optimizer,
    train_dataset,
    val_dataset,
    args: ORPOTrainingArgs = ORPOTrainingArgs(),
    training_callback=None,
):
    """
    Training function for ORPO.
    """
print(f"Starting ORPO training..., iters: {args.iters}")
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
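    # One optimizer update: compute the ORPO loss and gradients for a batch,
    # average gradients across workers, and apply the update.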
def step(batch):
chosen, rejected, chosen_masks, rejected_masks, chosen_rewards, rejected_rewards = batch
(loss, reward, toks), grad = loss_value_and_grad(
model,
chosen,
rejected,
chosen_masks,
rejected_masks,
chosen_rewards,
rejected_rewards
)
grad = average_gradients(grad)
optimizer.update(model, grad)
return loss, reward, toks
def loss_wrapper(model, chosen, rejected, chosen_masks, rejected_masks,
chosen_rewards, rejected_rewards):
return orpo_loss(
model=model,
chosen=chosen,
rejected=rejected,
chosen_masks=chosen_masks,
rejected_masks=rejected_masks,
chosen_rewards=chosen_rewards,
rejected_rewards=rejected_rewards,
beta=args.beta,
reward_scaling=args.reward_scaling
)
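    # Differentiate the wrapped ORPO loss with respect to the model's trainable parameters.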
loss_value_and_grad = nn.value_and_grad(model, loss_wrapper)
# Training loop with progress tracking
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
trained_tokens = 0
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
        iterate_orpo_batches(
dataset=train_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
# Evaluate if needed
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_rewards = evaluate_orpo(
model=model,
dataset=val_dataset,
tokenizer=tokenizer,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
beta=args.beta,
reward_scaling=args.reward_scaling,
)
val_time = time.perf_counter() - stop
if rank == 0:
print(
f"Iter {it}: "
f"Val loss {val_loss:.3f}, "
f"Val chosen reward {val_rewards[0]:.3f}, "
f"Val rejected reward {val_rewards[1]:.3f}, "
f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
training_callback.on_val_loss_report({
"iteration": it,
"val_loss": val_loss,
"val_chosen_reward": val_rewards[0],
"val_rejected_reward": val_rewards[1],
"val_time": val_time,
})
start = time.perf_counter()
# Training step
loss, reward, toks = step(batch)
losses += loss
rewards += reward
n_tokens += toks
steps += 1
mx.eval(state, losses, rewards, n_tokens)
# Report training metrics if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses).item() / (steps * world_size)
train_rewards = [r / (steps * world_size) for r in mx.distributed.all_sum(rewards).tolist()]
n_tokens = mx.distributed.all_sum(n_tokens).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
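            # Peak Metal (GPU) memory in GB (mx.metal reports bytes).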
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
print(
f"Iter {it}: Train loss {train_loss:.3f}, "
f"Chosen reward {train_rewards[0]:.3f}, "
f"Rejected reward {train_rewards[1]:.3f}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Trained Tokens {trained_tokens}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if training_callback is not None:
training_callback.on_train_loss_report({
"iteration": it,
"train_loss": train_loss,
"train_chosen_reward": train_rewards[0],
"train_rejected_reward": train_rewards[1],
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
})
losses = 0
rewards = mx.zeros((2,))
n_tokens = 0
steps = 0
start = time.perf_counter()
# Save model weights if needed
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
# Save final weights
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")