mlx-examples/llms/mlx_lm/tuner/grpo_trainer.py
# Copyright © 2024 Apple Inc.
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from mlx.utils import tree_flatten
from ..models import cache
from ..utils import generation_stream
from .grpo_reward_functions import (
RewardFunctions,
r1_accuracy_reward_func,
r1_count_xml,
r1_extract_xml_answer,
r1_int_reward_func,
r1_soft_format_reward_func,
r1_strict_format_reward_func,
)
from .trainer import TrainingArgs, TrainingCallback, average_gradients, grad_checkpoint
@dataclass
class GRPOTrainingArgs(TrainingArgs):
group_size: int = field(
default=4,
metadata={"help": "Number of responses per prompt."},
)
beta: float = field(default=0.1, metadata={"help": "KL penalty coefficient."})
epsilon: float = field(
        default=1e-4, metadata={"help": "Epsilon for numerical stability when normalizing advantages."}
)
    max_completion_length: int = field(
        default=512, metadata={"help": "Maximum number of tokens to generate per completion."}
    )
    reference_model_path: Optional[str] = field(
default=None,
metadata={
"help": "Path to reference model weights. If None, uses the same model."
},
)
temperature: float = field(
default=1.0,
metadata={
"help": "Temperature for sampling. The higher the temperature, the more random the completions."
},
)
reward_weights: Optional[List[float]] = field(
default=None,
metadata={
"help": "Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are weighted equally with weight `1.0`."
},
)
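
# Compute per-token log-probabilities of the target (next) tokens under `model`.
# Returns a ragged list with one array per sequence, trimmed to that sequence's
# unpadded length as given by `lengths`.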
def get_per_token_logps(model: nn.Module, inputs, lengths):
logits = model(inputs).astype(mx.float16)
logits = logits[:, :-1, :]
targets = inputs[:, 1:]
per_token_logps = []
for i in range(logits.shape[0]):
seq_len = int(lengths[i]) - 1
seq_logits = logits[i, :seq_len]
seq_targets = targets[i, :seq_len]
log_probs = nn.log_softmax(seq_logits, axis=-1)
token_log_probs = mx.take_along_axis(
log_probs, seq_targets.reshape(seq_len, 1), axis=-1
).squeeze(-1)
per_token_logps.append(token_log_probs)
mx.eval(logits)
return per_token_logps
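
# Single-step autoregressive generation: yields (token, logprobs) pairs one at a
# time, reusing a KV prompt cache between steps and detaching outputs with
# stop_gradient so generation never enters the training graph.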
def generate_step(
prompt: mx.array,
model: nn.Module,
max_tokens: int = 256,
sampler: Optional[Callable] = None,
logits_processors: Optional[List[Callable]] = None,
max_kv_size: Optional[int] = None,
prompt_cache: Optional[Any] = None,
) -> Generator[Tuple[mx.array, mx.array], None, None]:
    tokens = None
    y = prompt
    if prompt_cache is None:
        prompt_cache = cache.make_prompt_cache(model, max_kv_size=max_kv_size)
    if sampler is None:
        # The signature allows sampler=None, so default to greedy (argmax) sampling.
        sampler = lambda logprobs: mx.argmax(logprobs, axis=-1)
def _step(y):
with mx.stream(generation_stream):
logits = model(y[None], cache=prompt_cache)
logits = logits[:, -1, :]
if logits_processors:
nonlocal tokens
tokens = mx.concat([tokens, y]) if tokens is not None else y
for processor in logits_processors:
logits = processor(tokens, logits)
logprobs = logits - mx.logsumexp(logits, keepdims=True)
next_token = sampler(logprobs)
return mx.stop_gradient(next_token), mx.stop_gradient(logprobs.squeeze(0))
try:
with mx.stream(generation_stream):
y, logprobs = _step(y)
mx.eval(y, logprobs)
for n in range(max_tokens):
yield y.item(), logprobs
next_y, next_logprobs = _step(y)
mx.eval(next_y, next_logprobs)
y, logprobs = next_y, next_logprobs
if (n + 1) % 32 == 0:
mx.metal.clear_cache()
finally:
mx.metal.clear_cache()
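
# Generate `group_size` completions for each prompt in `prompt_tokens`.
# Returns the completion token arrays, their decoded texts, and the index of
# the prompt each completion belongs to.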
def generate_grpo(
model: nn.Module,
tokenizer,
prompt_tokens,
max_tokens: int,
group_size: int,
end_token: str = "</answer>",
temperature: float = 0.8,
batch_size: int = 1,
):
try:
end_sequence = mx.array(tokenizer.encode(end_token))
total_samples = len(prompt_tokens)
all_completions = []
all_completion_texts = []
batch_indices = []
def temp_sampler(logits):
return mx.random.categorical(logits / temperature)
for i in range(0, total_samples, batch_size):
current_batch_size = min(batch_size, total_samples - i)
batch_prompts = prompt_tokens[i : i + current_batch_size]
max_prompt_len = max(len(p) for p in batch_prompts)
padded_prompts = []
for prompt in batch_prompts:
padding = [tokenizer.pad_token_id] * (max_prompt_len - len(prompt))
padded_prompts.append(prompt + padding)
prompt_tensor = mx.stop_gradient(mx.array(padded_prompts))
if len(prompt_tensor.shape) == 1:
prompt_tensor = prompt_tensor[None, :]
if prompt_tensor.shape[1] == 0:
continue
expanded_prompts = mx.repeat(prompt_tensor, group_size, axis=0)
batch_results = []
total_prompt_samples = expanded_prompts.shape[0]
for prompt_idx in range(total_prompt_samples):
current_tokens = []
prompt_cache = cache.make_prompt_cache(model)
for token, _ in generate_step(
expanded_prompts[prompt_idx],
model,
max_tokens=max_tokens,
sampler=temp_sampler,
prompt_cache=prompt_cache,
):
if token == tokenizer.eos_token_id:
break
current_tokens.append(token)
if len(current_tokens) >= len(end_sequence) and mx.array_equal(
mx.array(current_tokens[-len(end_sequence):]), end_sequence
):
break
if current_tokens:
batch_results.append(mx.array(current_tokens))
if batch_results:
for j, completion_ids in enumerate(batch_results):
prompt_idx = i + (j // group_size)
if prompt_idx < total_samples:
batch_indices.append(prompt_idx)
completion_text = tokenizer.decode(completion_ids.tolist())
all_completions.append(mx.stop_gradient(completion_ids))
all_completion_texts.append(completion_text)
mx.metal.clear_cache()
finally:
mx.metal.clear_cache()
return all_completions, all_completion_texts, batch_indices
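
# GRPO loss: score each completion with the reward functions, convert rewards
# into group-relative advantages (z-scores within each prompt's group), and add
# a KL penalty that keeps the policy close to the reference model.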
def grpo_loss(
model,
ref_model,
tokenizer,
batch,
completions=None,
completion_texts=None,
batch_indices=None,
reward_funcs: Optional[List[RewardFunctions]] = None,
beta: float = 0.1,
group_size: int = 4,
epsilon: float = 1e-4,
max_tokens: int = 64,
temperature: float = 0.8,
reward_weights: Optional[List[float]] = None,
batch_size: int = 1,
is_validation: bool = False
):
prompt_tokens, _, prompt_text, answer_text = batch
    if completions is not None and completion_texts is not None and batch_indices is not None:
        all_completions = completions
        all_completion_texts = completion_texts
else:
all_completions, all_completion_texts, batch_indices = generate_grpo(
model=model,
tokenizer=tokenizer,
prompt_tokens=prompt_tokens,
max_tokens=max_tokens,
group_size=group_size,
temperature=temperature,
batch_size=batch_size
)
if not all_completions:
raise ValueError(
"No completions were generated. Please check your model and inputs."
)
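    # Re-group completions so that all completions for the same prompt are
    # contiguous, and build prompt/answer lists aligned with that ordering.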
expanded_answers = []
expanded_prompts = []
unique_prompt_indices = sorted(set(batch_indices))
grouped_completions = {idx: [] for idx in unique_prompt_indices}
for i, completion_idx in enumerate(batch_indices):
grouped_completions[completion_idx].append(i)
ordered_completions = []
ordered_completion_texts = []
ordered_batch_indices = []
for prompt_idx in unique_prompt_indices:
completion_indices = grouped_completions[prompt_idx]
for idx in completion_indices:
ordered_completions.append(all_completions[idx])
ordered_completion_texts.append(all_completion_texts[idx])
ordered_batch_indices.append(prompt_idx)
expanded_prompts.append(prompt_text[prompt_idx])
expanded_answers.append(answer_text[prompt_idx])
all_completions = ordered_completions
all_completion_texts = ordered_completion_texts
batch_indices = ordered_batch_indices
max_length = max(ids.shape[0] for ids in all_completions)
padded_completions = []
attention_masks = []
for completion_ids in all_completions:
completion_tensor = mx.array(completion_ids.tolist())
padding_length = max_length - completion_tensor.shape[0]
if padding_length > 0:
padding = mx.zeros((padding_length,), dtype=completion_tensor.dtype)
padded_ids = mx.concatenate([completion_tensor, padding])
mask = mx.concatenate(
[mx.ones_like(completion_tensor), mx.zeros_like(padding)]
)
else:
padded_ids = completion_tensor
mask = mx.ones_like(completion_tensor)
padded_completions.append(padded_ids)
attention_masks.append(mask)
inputs = mx.stack(padded_completions)
attention_mask = mx.stack(attention_masks)
lengths = attention_mask.sum(axis=1)
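    # Per-token log-probabilities under the current policy and the reference
    # model (the policy itself if ref_model is None).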
token_log_probs = get_per_token_logps(model, inputs, lengths)
mx.eval(token_log_probs)
if ref_model is None:
ref_token_log_probs = token_log_probs
else:
ref_token_log_probs = get_per_token_logps(ref_model, inputs, lengths)
mx.eval(ref_token_log_probs)
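    # Pad the ragged per-sequence log-probs to a common length so they can be stacked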
max_len = max(x.shape[0] for x in token_log_probs)
padded_log_probs = []
padded_ref_log_probs = []
for i in range(len(token_log_probs)):
seq_len = token_log_probs[i].shape[0]
padding = mx.zeros((max_len - seq_len,))
padded_log_probs.append(mx.concatenate([token_log_probs[i], padding]))
padded_ref_log_probs.append(mx.concatenate([ref_token_log_probs[i], padding]))
token_log_probs = mx.stack(padded_log_probs)
ref_token_log_probs = mx.stack(padded_ref_log_probs)
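    # Evaluate every reward function on the (prompt, completion, answer) triples;
    # stacking yields a (num_completions, num_reward_funcs) reward matrix.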
all_func_rewards = []
for reward_func in reward_funcs:
func_rewards = mx.array(
reward_func(
prompts=expanded_prompts,
completions=all_completion_texts,
answer=expanded_answers,
)
)
all_func_rewards.append(func_rewards)
rewards = mx.stack(all_func_rewards, axis=1)
if reward_weights is not None:
if len(reward_weights) != len(reward_funcs):
raise ValueError(
f"Number of reward weights ({len(reward_weights)}) must match number of reward "
f"functions ({len(reward_funcs)})"
)
reward_weights = mx.array(reward_weights, dtype=mx.float32)
else:
reward_weights = mx.ones(len(reward_funcs), dtype=mx.float32)
rewards = (rewards * mx.expand_dims(reward_weights, 0)).sum(axis=1)
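    # Group rewards by source prompt and compute each completion's advantage as
    # its reward's z-score within the group (zero for single-completion groups).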
num_unique_prompts = len(unique_prompt_indices)
rewards_by_prompt = [[] for _ in range(num_unique_prompts)]
for i, prompt_idx in enumerate(batch_indices):
prompt_position = unique_prompt_indices.index(prompt_idx)
rewards_by_prompt[prompt_position].append(rewards[i])
advantages = mx.zeros_like(rewards)
for i, prompt_rewards in enumerate(rewards_by_prompt):
if len(prompt_rewards) > 1:
prompt_rewards = mx.array(prompt_rewards)
mean_reward = mx.mean(prompt_rewards)
std_reward = mx.std(prompt_rewards)
indices = [
j
for j, idx in enumerate(batch_indices)
if idx == unique_prompt_indices[i]
]
for j, idx in enumerate(indices):
advantages[idx] = (prompt_rewards[j] - mean_reward) / (
std_reward + epsilon
)
else:
idx = batch_indices.index(unique_prompt_indices[i])
advantages[idx] = 0.0
# Compute KL divergence using Schulman's approximator
kl_div = (
mx.exp(ref_token_log_probs - token_log_probs)
- (ref_token_log_probs - token_log_probs)
- 1
)
# Create mask for valid tokens
length_mask = mx.arange(inputs.shape[1] - 1)[None, :] < (lengths[:, None] - 1)
# Compute policy ratio
    policy_ratio = mx.exp(token_log_probs - mx.stop_gradient(ref_token_log_probs))
# Compute per-token loss
per_token_loss = -(
(policy_ratio * advantages.reshape(-1, 1) - beta * kl_div) * length_mask
)
# Average over tokens
sequence_sums = per_token_loss.sum(axis=1)
sequence_lengths = length_mask.sum(axis=1)
loss = (sequence_sums / sequence_lengths).mean()
# Calculate mean KL divergence for metrics
mean_kl = ((kl_div * length_mask).sum(axis=1) / length_mask.sum(axis=1)).mean()
# Collect reward metrics
reward_metrics = {}
    for i, reward_func in enumerate(reward_funcs):
        func_name = reward_func.__name__
        # Reuse the per-function rewards computed above instead of re-invoking the reward functions
        func_rewards = all_func_rewards[i]
reward_metrics[f"{func_name}_mean"] = mx.mean(func_rewards)
reward_metrics[f"{func_name}_std"] = mx.std(func_rewards)
grouped_rewards_mean = mx.array(
[mx.mean(mx.array(rewards)) for rewards in rewards_by_prompt]
)
    grouped_rewards_std = mx.array(
        [
            mx.std(mx.array(rewards)) if len(rewards) > 1 else mx.array(0.0)
            for rewards in rewards_by_prompt
        ]
    )
metrics = {
"total_rewards_mean": mx.mean(rewards),
"total_rewards_std": mx.std(rewards),
"grouped_rewards_mean": mx.mean(grouped_rewards_mean),
"grouped_rewards_std": mx.mean(grouped_rewards_std),
"kl": mean_kl,
**reward_metrics,
}
if is_validation and all_completion_texts:
print("\n=== Validation Sample Details ===")
# Print the input context (prompt)
last_prompt_idx = batch_indices[-1] if batch_indices else 0
if last_prompt_idx < len(prompt_text):
print(f"\n📋 Raw Prompt:\n{prompt_text[last_prompt_idx]}")
print("\n" + "=" * 10 + "\n")
# Get the actual tokenized prompt that was fed to the model
if last_prompt_idx < len(prompt_tokens):
actual_prompt = tokenizer.decode(prompt_tokens[last_prompt_idx])
print(f"\n🔄 Model Input:\n{actual_prompt}")
print("\n" + "=" * 10 + "\n")
print(f"\n📝 Generation:\n{all_completion_texts[-1]}")
print("\n" + "=" * 10 + "\n")
# Make sure we have a valid index for answer_text
if last_prompt_idx < len(answer_text):
print(f"\n✅ Answer:\n{answer_text[last_prompt_idx]}")
print("\n" + "=" * 10 + "\n")
# Only try to extract if r1_extract_xml_answer is defined
if "r1_extract_xml_answer" in globals():
print(
f"\n🔍 Extracted Answer:\n{r1_extract_xml_answer(all_completion_texts[-1])}"
)
print("\n" + "=" * 35 + "\n")
mx.metal.clear_cache()
return loss, sequence_lengths.sum(), metrics
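
# Yield batches of (prompt_tokens, answer_tokens, prompt_text, answer_text),
# sorted by combined prompt+answer length and strided by the distributed world size.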
def iterate_grpo_batches(dataset, batch_size, max_seq_length, train=False):
if not dataset or not isinstance(dataset[0], tuple) or len(dataset[0]) != 4:
raise ValueError(
"Dataset must be list of (prompt_tokens, answer_tokens, prompt_str, answer_str) tuples"
)
def length_key(i):
return len(dataset[i][0]) + len(dataset[i][1])
idx = sorted(range(len(dataset)), key=length_key)
if len(dataset) < batch_size:
raise ValueError(
f"Dataset must have at least batch_size={batch_size} "
f"examples but only has {len(dataset)}."
)
step = mx.distributed.init().size()
if batch_size % step != 0:
raise ValueError("The batch size must be divisible by the number of workers")
def batch_index_generator():
for i in range(0, len(idx) - batch_size + 1, batch_size):
yield idx[i : i + batch_size : step]
while True:
indices = (
np.random.permutation(list(batch_index_generator()))
if train
else batch_index_generator()
)
for batch_idx in indices:
current_batch = [dataset[j] for j in batch_idx]
prompts_tokens = [item[0] for item in current_batch]
answers_tokens = [item[1] for item in current_batch]
prompts_text = [item[2] for item in current_batch]
answers_text = [item[3] for item in current_batch]
            if any(len(p) > max_seq_length for p in prompts_tokens):
                print(
                    f"[WARNING] Some prompts are longer than {max_seq_length} tokens. "
                    "They are not truncated here; consider raising max_seq_length."
                )
yield prompts_tokens, answers_tokens, prompts_text, answers_text
if not train:
break
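
# Validation pass: token-weighted average of the GRPO loss and reward metrics
# over `num_batches` batches, aggregated across workers with all_sum.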
def evaluate_grpo(
model: nn.Module,
ref_model: Optional[nn.Module],
dataset,
tokenizer,
batch_size,
num_batches,
beta: float,
epsilon: float,
group_size: int,
max_seq_length: int,
max_tokens: int,
temperature: float,
reward_funcs: Optional[List[RewardFunctions]] = [
r1_accuracy_reward_func,
r1_int_reward_func,
r1_strict_format_reward_func,
r1_soft_format_reward_func,
r1_count_xml,
],
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches,
):
all_losses = 0
ntokens = 0
all_metrics = None
index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1)
for _, batch in zip(
index_iterator,
iterate_batches(
dataset=dataset,
batch_size=batch_size,
max_seq_length=max_seq_length,
),
):
losses, toks, metrics = loss_fn(
model=model,
tokenizer=tokenizer,
batch=batch,
reward_funcs=reward_funcs,
beta=beta,
group_size=group_size,
epsilon=epsilon,
ref_model=ref_model,
temperature=temperature,
max_tokens=max_tokens,
is_validation=True
)
all_losses += losses * toks
ntokens += toks
if all_metrics is None:
all_metrics = {k: v * toks for k, v in metrics.items()}
else:
for k, v in metrics.items():
all_metrics[k] += v * toks
mx.eval(all_losses, ntokens)
all_losses = mx.distributed.all_sum(all_losses, stream=mx.cpu)
ntokens = mx.distributed.all_sum(ntokens, stream=mx.cpu)
all_metrics = {k: mx.distributed.all_sum(v) for k, v in all_metrics.items()}
avg_metrics = {k: (v / ntokens).item() for k, v in all_metrics.items()}
avg_loss = (all_losses / ntokens).item()
return avg_loss, ntokens, avg_metrics
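
# Main GRPO training loop: periodic validation, a generate-then-optimize step
# for each batch, metric reporting, and periodic adapter checkpointing.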
def train_grpo(
model: nn.Module,
ref_model: Optional[nn.Module],
tokenizer,
optimizer,
train_dataset,
val_dataset,
reward_funcs: Optional[List[RewardFunctions]] = [
r1_accuracy_reward_func,
r1_int_reward_func,
r1_strict_format_reward_func,
r1_soft_format_reward_func,
r1_count_xml,
],
args: GRPOTrainingArgs = GRPOTrainingArgs(),
loss_fn: callable = grpo_loss,
iterate_batches: callable = iterate_grpo_batches,
training_callback: TrainingCallback = None,
):
    print(
        f"Starting GRPO training with {len(reward_funcs)} reward functions "
        f"for {args.iters} iterations..."
    )
world = mx.distributed.init()
world_size = world.size()
rank = world.rank()
if world_size > 1:
print(f"Node {rank} of {world_size}")
if args.grad_checkpoint:
grad_checkpoint(model.layers[0])
state = [model.state, optimizer.state]
def step(batch):
        # Unpack the batch: token ids plus raw text for prompts and answers
        prompt_tokens, answer_tokens, prompt_text, answer_text = batch
        # Generate completions first, outside of the gradient computation;
        # generate_grpo detaches its outputs with stop_gradient.
all_completions, all_completion_texts, batch_indices = generate_grpo(
model=model,
tokenizer=tokenizer,
prompt_tokens=prompt_tokens,
max_tokens=args.max_completion_length,
group_size=args.group_size,
temperature=args.temperature
)
        # Compute the loss and gradients using the pre-generated completions
(loss, toks, metrics), grad = loss_value_and_grad(
model,
tokenizer=tokenizer,
            batch=(prompt_tokens, answer_tokens, prompt_text, answer_text),
completions=all_completions,
completion_texts=all_completion_texts,
batch_indices=batch_indices,
reward_funcs=reward_funcs,
beta=args.beta,
group_size=args.group_size,
epsilon=args.epsilon,
ref_model=ref_model,
)
grad = average_gradients(grad)
optimizer.update(model, grad)
return loss, toks, metrics
loss_value_and_grad = nn.value_and_grad(model, loss_fn)
losses = 0
n_tokens = 0
steps = 0
trained_tokens = 0
accumulated_metrics = {
"total_rewards_mean": 0,
"total_rewards_std": 0,
"grouped_rewards_mean": 0,
"grouped_rewards_std": 0,
"kl": 0,
}
for reward_func in reward_funcs:
func_name = reward_func.__name__
accumulated_metrics[f"{func_name}_mean"] = 0
accumulated_metrics[f"{func_name}_std"] = 0
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_batches(
dataset=train_dataset,
batch_size=args.batch_size,
max_seq_length=args.max_seq_length,
train=True,
),
):
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
val_loss, val_ntokens, val_metrics = evaluate_grpo(
model=model,
dataset=val_dataset,
loss_fn=loss_fn,
ref_model=ref_model,
reward_funcs=reward_funcs,
tokenizer=tokenizer,
group_size=args.group_size,
batch_size=args.batch_size,
num_batches=args.val_batches,
max_seq_length=args.max_seq_length,
max_tokens=args.max_completion_length,
beta=args.beta,
epsilon=args.epsilon,
temperature=args.temperature,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
if rank == 0:
val_metrics_str = (
f"Val loss {val_loss:.3f}, "
f"Val total_rewards_mean {val_metrics['total_rewards_mean']:.3f}, "
f"Val total_rewards_std {val_metrics['total_rewards_std']:.3f}, "
f"Val grouped_rewards_mean {val_metrics['grouped_rewards_mean']:.3f}, "
f"Val grouped_rewards_std {val_metrics['grouped_rewards_std']:.3f}, "
f"Val kl {val_metrics['kl']:.3f}"
)
for i, reward_func in enumerate(reward_funcs):
val_metrics_str += (
f", Val {reward_func.__name__}_mean {val_metrics[f'{reward_func.__name__}_mean']:.3f}, "
f"Val {reward_func.__name__}_std {val_metrics[f'{reward_func.__name__}_std']:.3f}"
)
print(
f"Iter {it}: {val_metrics_str}, " f"Val took {val_time:.3f}s",
flush=True,
)
if training_callback is not None:
training_callback.on_val_loss_report(
{
"iteration": it,
"val_loss": val_loss,
**{f"val_{k}": v for k, v in val_metrics.items()},
"val_time": val_time,
}
)
start = time.perf_counter()
loss, toks, metrics = step(batch)
losses += loss
n_tokens += toks
steps += 1
mx.metal.clear_cache()
for k, v in metrics.items():
accumulated_metrics[k] += v
mx.eval(state, losses, n_tokens)
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
            train_loss /= steps * world_size
avg_metrics = {
k: v / (steps * world_size) for k, v in accumulated_metrics.items()
}
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
train_metrics_str = (
f"Train loss {train_loss:.3f}, "
f"Total rewards mean {avg_metrics['total_rewards_mean']:.3f}, "
f"Total rewards std {avg_metrics['total_rewards_std']:.3f}, "
f"Grouped rewards mean {avg_metrics['grouped_rewards_mean']:.3f}, "
f"Grouped rewards std {avg_metrics['grouped_rewards_std']:.3f}, "
f"KL {avg_metrics['kl']:.3f}"
)
for i, reward_func in enumerate(reward_funcs):
func_name = reward_func.__name__
train_metrics_str += (
f", {func_name} mean {avg_metrics[f'{func_name}_mean']:.3f}, "
f"{func_name} std {avg_metrics[f'{func_name}_std']:.3f}"
)
print(
f"Iter {it}: {train_metrics_str}, "
f"Learning Rate {learning_rate:.3e}, "
f"It/sec {it_sec:.3f}, "
f"Tokens/sec {tokens_sec:.3f}, "
f"Peak mem {peak_mem:.3f} GB",
flush=True,
)
if training_callback is not None:
training_callback.on_train_loss_report(
{
"iteration": it,
"train_loss": train_loss,
**{f"train_{k}": v for k, v in avg_metrics.items()},
"learning_rate": learning_rate,
"iterations_per_second": it_sec,
"tokens_per_second": tokens_sec,
"trained_tokens": trained_tokens,
"peak_memory": peak_mem,
}
)
losses = 0
n_tokens = 0
steps = 0
start = time.perf_counter()
if it % args.steps_per_save == 0:
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
checkpoint = (
Path(args.adapter_file).parent / f"{it:07d}_adapters.safetensors"
)
mx.save_safetensors(str(checkpoint), adapter_weights)
print(
f"Iter {it}: Saved adapter weights to "
f"{args.adapter_file} and {checkpoint}."
)
adapter_weights = dict(tree_flatten(model.trainable_parameters()))
mx.save_safetensors(str(args.adapter_file), adapter_weights)
print(f"Saved final weights to {args.adapter_file}.")