mirror of https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 01:17:28 +08:00
fix lora timings after validation (#1278)
This commit is contained in:
parent f8cbf159e0
commit 42413c5d85
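In the old code, `it_sec` and `tokens_sec` were computed from a wall-clock window `stop - start`, and `start` was reset whenever a validation pass finished. A training-speed report that followed a validation therefore dropped the time of the steps run before that validation from its denominator while still dividing by the full `steps_per_report`, overstating the reported speed. The fix times each step individually and accumulates the result in a new `train_time` counter, which excludes validation and reporting overhead by construction; `it_sec` and `tokens_sec` are then derived from that accumulator. A self-contained sketch of the pattern follows the diff.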
@@ -233,8 +233,8 @@ def train(
     n_tokens = 0
     steps = 0
     trained_tokens = 0
+    train_time = 0
     # Main training loop
-    start = time.perf_counter()
     for it, batch in zip(
         range(1, args.iters + 1),
         iterate_batches(
@@ -245,10 +245,11 @@ def train(
             train=True,
         ),
     ):
+        tic = time.perf_counter()
         # Report validation loss if needed, the first validation loss
         # is always measured before any training.
         if it == 1 or it % args.steps_per_eval == 0 or it == args.iters:
-            stop = time.perf_counter()
+            tic = time.perf_counter()
             val_loss = evaluate(
                 model=model,
                 dataset=val_dataset,
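The single `tic` variable is assigned up to three times per iteration, which is easy to misread. A distilled form of the timing logic, with hypothetical helper names `should_validate`, `run_validation`, and `run_step` standing in for the real conditions and calls:

tic = time.perf_counter()               # step timer for this iteration
if should_validate:
    tic = time.perf_counter()           # reused to time the validation pass
    run_validation()
    val_time = time.perf_counter() - tic
    tic = time.perf_counter()           # restarted so the step timer excludes validation
run_step()
train_time += time.perf_counter() - tic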
@@ -259,7 +260,7 @@ def train(
                 max_seq_length=args.max_seq_length,
                 iterate_batches=iterate_batches,
             )
-            val_time = time.perf_counter() - stop
+            val_time = time.perf_counter() - tic
             if rank == 0:
                 print(
                     f"Iter {it}: "
@@ -276,24 +277,23 @@ def train(
                 }
                 training_callback.on_val_loss_report(val_info)

-            start = time.perf_counter()
+            tic = time.perf_counter()

         lvalue, toks = step(batch)
         losses += lvalue
         n_tokens += toks
         steps += 1
         mx.eval(state, losses, n_tokens)
+        train_time += time.perf_counter() - tic

         # Report training loss if needed
         if it % args.steps_per_report == 0 or it == args.iters:
-            stop = time.perf_counter()
-
             train_loss = mx.distributed.all_sum(losses, stream=mx.cpu).item()
             train_loss /= steps * mx.distributed.init().size()
             n_tokens = mx.distributed.all_sum(n_tokens, stream=mx.cpu).item()
             learning_rate = optimizer.learning_rate.item()
-            it_sec = args.steps_per_report / (stop - start)
-            tokens_sec = float(n_tokens) / (stop - start)
+            it_sec = args.steps_per_report / train_time
+            tokens_sec = float(n_tokens) / train_time
             trained_tokens += n_tokens
             peak_mem = mx.metal.get_peak_memory() / 1e9
             if rank == 0:
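As a concrete check of the new arithmetic: with `steps_per_report = 10`, an accumulated `train_time` of 2.0 s, and 8192 tokens in the window, the report shows `it_sec = 10 / 2.0 = 5.0` and `tokens_sec = 8192 / 2.0 = 4096`, whether or not a validation ran mid-window. Under the old `stop - start` window, a validation at step 5 would have reset `start` and silently dropped the time of steps 1-5 from the denominator, overstating both numbers.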
@@ -322,7 +322,7 @@ def train(
             losses = 0
             n_tokens = 0
             steps = 0
-            start = time.perf_counter()
+            train_time = 0

         # Save adapter weights
         if it % args.steps_per_save == 0:
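For reference, here is a minimal, runnable sketch of the accumulate-and-reset timing pattern this commit adopts. `fake_step` and `fake_validation` are toy stand-ins and the constants are arbitrary; this is an illustration of the pattern, not the repository's code.

import time

ITERS = 30
STEPS_PER_REPORT = 10
STEPS_PER_EVAL = 15

def fake_step():
    # Stand-in for one training step; returns tokens processed.
    time.sleep(0.01)
    return 128

def fake_validation():
    # Stand-in for a (slow) validation pass.
    time.sleep(0.05)

train_time = 0.0
n_tokens = 0

for it in range(1, ITERS + 1):
    tic = time.perf_counter()
    if it == 1 or it % STEPS_PER_EVAL == 0 or it == ITERS:
        tic = time.perf_counter()
        fake_validation()
        val_time = time.perf_counter() - tic
        print(f"Iter {it}: Val took {val_time:.3f}s")
        tic = time.perf_counter()  # restart so validation time never enters train_time

    n_tokens += fake_step()
    train_time += time.perf_counter() - tic

    if it % STEPS_PER_REPORT == 0 or it == ITERS:
        # Throughput over accumulated step time only, as in the commit.
        it_sec = STEPS_PER_REPORT / train_time
        tokens_sec = n_tokens / train_time
        print(f"Iter {it}: {it_sec:.2f} it/s, {tokens_sec:.0f} tok/s")
        n_tokens = 0
        train_time = 0.0

Because `train_time` only grows between a `tic` taken after any validation and the end of the step, validation cost never leaks into the throughput numbers.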