fix lora timings after validation (#1278)
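
Before this change, throughput was measured against a single wall-clock window: `start` was reset after every report and after every validation, and `stop - start` was the denominator for it/sec and tokens/sec. When a validation fell partway through a report interval, the reset shrank the window while the step and token counts still covered the full interval, so the speeds reported right after a validation were inflated. The fix times each iteration individually and accumulates the result in `train_time`, which excludes validation time without distorting the measured window.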

Awni Hannun, 2025-02-11 16:48:55 -08:00 (committed by GitHub)
parent f8cbf159e0
commit 42413c5d85

@@ -233,8 +233,8 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
-    start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
         iterate_batches(
@@ -245,10 +245,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
@@ -259,7 +260,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)

-            start = time.perf_counter()
+            tic = time.perf_counter()

         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
@@ -322,7 +322,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0

         # Save adapter weights
         if it % args.steps_per_save == 0:
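
The net effect: `tic` marks the start of each iteration (and is restarted after any validation), `train_time` accumulates only the seconds spent in training steps, and each report divides by that accumulator before resetting it. Here is a minimal, self-contained sketch of the same pattern, with `train_step` and `validate` as hypothetical stand-ins for the real work:

import time

def train_step():
    time.sleep(0.01)  # stand-in for one optimizer step

def validate():
    time.sleep(0.2)  # stand-in for a comparatively slow eval pass

steps_per_report = 10
steps_per_eval = 25
train_time = 0.0  # accumulates time spent in training steps only

for it in range(1, 51):
    tic = time.perf_counter()
    if it == 1 or it % steps_per_eval == 0:
        validate()
        # Restart the clock so eval time is not charged to training.
        tic = time.perf_counter()
    train_step()
    train_time += time.perf_counter() - tic

    if it % steps_per_report == 0:
        # Throughput stays accurate even when a validation ran mid-window.
        print(f"Iter {it}: {steps_per_report / train_time:.3f} it/s")
        train_time = 0.0

With the old single-window timing, the report at iteration 30 would have divided ten steps by only five steps' worth of time, because the validation at iteration 25 reset the window mid-interval.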