Improve printing for FLUX distributed training

This commit is contained in:
Angelos Katharopoulos 2025-01-13 22:32:35 -08:00
parent c117af83b8
commit b9eff0d744

View File

@ -261,19 +261,23 @@ if __name__ == "__main__":
generate_progress_images(0, flux, args)
grads = None
losses = []
batch_cnt = 0
total_loss = 0
tic = time.time()
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state)
losses.append(loss.item())
total_loss = total_loss + loss
batch_cnt += 1
mx.eval(total_loss, grads, state)
if (i + 1) % 10 == 0:
if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
total_loss = total_loss.item() / batch_cnt
print(
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Iter: {i + 1} Loss: {total_loss:.3f} "
f"It/s: {batch_cnt / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB",
flush=True,
)
@ -285,7 +289,8 @@ if __name__ == "__main__":
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
if (i + 1) % 10 == 0:
losses = []
total_loss = 0
batch_cnt = 0
tic = time.time()
save_adapters("final_adapters.safetensors", flux, args)