Mirror of https://github.com/ml-explore/mlx-examples.git
Improve printing for FLUX distributed training
commit b9eff0d744 (parent c117af83b8)
@@ -261,19 +261,23 @@ if __name__ == "__main__":
     generate_progress_images(0, flux, args)
 
     grads = None
-    losses = []
+    batch_cnt = 0
+    total_loss = 0
     tic = time.time()
     for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
         loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
-        mx.eval(loss, grads, state)
-        losses.append(loss.item())
+        total_loss = total_loss + loss
+        batch_cnt += 1
+        mx.eval(total_loss, grads, state)
 
-        if (i + 1) % 10 == 0:
+        if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
             toc = time.time()
             peak_mem = mx.metal.get_peak_memory() / 1024**3
+            total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
+            total_loss = total_loss.item() / batch_cnt
             print(
-                f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
-                f"It/s: {10 / (toc - tic):.3f} "
+                f"Iter: {i + 1} Loss: {total_loss:.3f} "
+                f"It/s: {batch_cnt / (toc - tic):.3f} "
                 f"Peak mem: {peak_mem:.3f} GB",
                 flush=True,
             )
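The key change: rather than appending a synced Python float (`loss.item()`) on every iteration, the loop now accumulates a lazy `total_loss` array plus a `batch_cnt`, and only materializes and reduces them when it is time to print. Below is a minimal, self-contained sketch of this logging pattern, assuming only that MLX is installed; a dummy constant loss stands in for the real training step, and, unlike the hunk above, every rank participates in the `all_sum` while only rank 0 prints:

import time

import mlx.core as mx

group = mx.distributed.init()  # falls back to a single-process group

total_loss = 0
batch_cnt = 0
tic = time.time()
for i in range(100):
    loss = mx.array(1.0)  # stand-in for the real per-batch training loss
    total_loss = total_loss + loss  # lazy accumulation, no host sync here
    batch_cnt += 1
    mx.eval(total_loss)  # keep the lazy compute graph from growing

    if (i + 1) % 10 == 0:
        # Every rank joins the reduction; only the print is gated on rank 0.
        total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
        avg_loss = total_loss.item() / (batch_cnt * group.size())
        if group.rank() == 0:
            toc = time.time()
            print(
                f"Iter: {i + 1} Loss: {avg_loss:.3f} "
                f"It/s: {batch_cnt / (toc - tic):.3f}",
                flush=True,
            )
        # Reset the accumulators for the next reporting window.
        total_loss = 0
        batch_cnt = 0
        tic = time.time()

Calling `mx.eval` on the running sum each step keeps the lazy graph bounded without forcing a host round-trip the way a per-step `.item()` does.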
@@ -285,7 +289,8 @@ if __name__ == "__main__":
             save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
 
         if (i + 1) % 10 == 0:
-            losses = []
+            total_loss = 0
+            batch_cnt = 0
             tic = time.time()
 
     save_adapters("final_adapters.safetensors", flux, args)
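Design note: this second hunk only resets the accumulators after each report, mirroring the old `losses = []`. Previously every rank printed its own local average, interleaving output across processes; with this change only rank 0 prints, and the value it reports is the all-reduced accumulated loss. When the script is launched on multiple processes (for example with mpirun or MLX's mlx.launch helper), log lines are therefore not duplicated.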