From b9eff0d7442d70fb245676d8832579460f49624f Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 13 Jan 2025 22:32:35 -0800
Subject: [PATCH] Improve printing for FLUX distributed training

---
 flux/dreambooth.py | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/flux/dreambooth.py b/flux/dreambooth.py
index f82178b9..dae23992 100644
--- a/flux/dreambooth.py
+++ b/flux/dreambooth.py
@@ -261,19 +261,23 @@ if __name__ == "__main__":
     generate_progress_images(0, flux, args)
 
     grads = None
-    losses = []
+    batch_cnt = 0
+    total_loss = 0
     tic = time.time()
     for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
         loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
-        mx.eval(loss, grads, state)
-        losses.append(loss.item())
+        total_loss = total_loss + loss
+        batch_cnt += 1
+        mx.eval(total_loss, grads, state)
 
-        if (i + 1) % 10 == 0:
+        if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
             toc = time.time()
             peak_mem = mx.metal.get_peak_memory() / 1024**3
+            total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
+            total_loss = total_loss.item() / batch_cnt
             print(
-                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
-                f"It/s: {10 / (toc - tic):.3f} "
+                f"Iter: {i + 1} Loss: {total_loss:.3f} "
+                f"It/s: {batch_cnt / (toc - tic):.3f} "
                 f"Peak mem: {peak_mem:.3f} GB",
                 flush=True,
             )
@@ -285,7 +289,8 @@ if __name__ == "__main__":
             save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
 
         if (i + 1) % 10 == 0:
-            losses = []
+            total_loss = 0
+            batch_cnt = 0
             tic = time.time()
 
     save_adapters("final_adapters.safetensors", flux, args)
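
For reference, below is a minimal standalone sketch of the logging pattern this
patch adopts: accumulate the loss lazily as an mx.array instead of calling
loss.item() every step, periodically all-reduce it across ranks, and print from
rank 0 only. The loop and per-batch loss are placeholders rather than the FLUX
training code, and in this sketch the all_sum collective runs on every rank
(with only the print gated on rank) and the average is taken over the global
batch count; it assumes an MLX build with distributed support.

import time

import mlx.core as mx

world = mx.distributed.init()  # falls back to a single-rank group if not distributed

total_loss = 0
batch_cnt = 0
tic = time.time()
for i in range(100):
    # Placeholder per-batch loss; a real loop would run a training step here.
    loss = mx.array(1.0) / (i + 1)

    # Accumulate lazily and evaluate once per step, avoiding a host sync
    # from loss.item() on every iteration.
    total_loss = total_loss + loss
    batch_cnt += 1
    mx.eval(total_loss)

    if (i + 1) % 10 == 0:
        # Collectives must be issued by every rank; only the print is rank-gated.
        loss_sum = mx.distributed.all_sum(total_loss, stream=mx.cpu)
        avg = loss_sum.item() / (batch_cnt * world.size())
        if world.rank() == 0:
            toc = time.time()
            print(
                f"Iter: {i + 1} Loss: {avg:.3f} "
                f"It/s: {batch_cnt / (toc - tic):.3f}",
                flush=True,
            )
        total_loss = 0
        batch_cnt = 0
        tic = time.time()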