Merge branch 'ml-explore:main' into adding-dpo-training

This commit is contained in:
Gökdeniz Gülmez
2025-02-12 11:09:58 +01:00
committed by GitHub
8 changed files with 89 additions and 26 deletions

View File

@@ -233,8 +233,8 @@ def train(
n_tokens = 0
steps = 0
trained_tokens = 0
train_time = 0
# Main training loop
start = time.perf_counter()
for it, batch in zip(
range(1, args.iters + 1),
iterate_batches(
@@ -245,10 +245,11 @@ def train(
train=True,
),
):
tic = time.perf_counter()
# Report validation loss if needed, the first validation loss
# is always measured before any training.
if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
stop = time.perf_counter()
tic = time.perf_counter()
val_loss = evaluate(
model=model,
dataset=val_dataset,
@@ -259,7 +260,7 @@ def train(
max_seq_length=args.max_seq_length,
iterate_batches=iterate_batches,
)
val_time = time.perf_counter() - stop
val_time = time.perf_counter() - tic
if rank == 0:
print(
f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
}
training_callback.on_val_loss_report(val_info)
start = time.perf_counter()
tic = time.perf_counter()
lvalue, toks = step(batch)
losses += lvalue
n_tokens += toks
steps += 1
mx.eval(state, losses, n_tokens)
train_time += time.perf_counter() - tic
# Report training loss if needed
if it % args.steps_per_report == 0 or it == args.iters:
stop = time.perf_counter()
train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
train_loss /= steps * mx.distributed.init().size()
n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
learning_rate = optimizer.learning_rate.item()
it_sec = args.steps_per_report / (stop - start)
tokens_sec = float(n_tokens) / (stop - start)
it_sec = args.steps_per_report / train_time
tokens_sec = float(n_tokens) / train_time
trained_tokens += n_tokens
peak_mem = mx.metal.get_peak_memory() / 1e9
if rank == 0:
@@ -322,7 +322,7 @@ def train(
losses = 0
n_tokens = 0
steps = 0
start = time.perf_counter()
train_time = 0
# Save adapter weights
if it % args.steps_per_save == 0:

View File

@@ -89,6 +89,7 @@ def linear_to_lora_layers(
"mixtral",
"nemotron",
"stablelm",
"hunyuan",
"qwen2",
"qwen2_moe",
"phimoe",