mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Improve printing for FLUX distributed training
This commit is contained in:
parent
c117af83b8
commit
b9eff0d744
@ -261,19 +261,23 @@ if __name__ == "__main__":
|
|||||||
generate_progress_images(0, flux, args)
|
generate_progress_images(0, flux, args)
|
||||||
|
|
||||||
grads = None
|
grads = None
|
||||||
losses = []
|
batch_cnt = 0
|
||||||
|
total_loss = 0
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
for i, batch in zip(range(args.iterations), trainer.iterate(args.batch_size)):
|
||||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||||
mx.eval(loss, grads, state)
|
total_loss = total_loss + loss
|
||||||
losses.append(loss.item())
|
batch_cnt += 1
|
||||||
|
mx.eval(total_loss, grads, state)
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0 and mx.distributed.init().rank() == 0:
|
||||||
toc = time.time()
|
toc = time.time()
|
||||||
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
peak_mem = mx.metal.get_peak_memory() / 1024**3
|
||||||
|
total_loss = mx.distributed.all_sum(total_loss, stream=mx.cpu)
|
||||||
|
total_loss = total_loss.item() / batch_cnt
|
||||||
print(
|
print(
|
||||||
f"Iter: {i + 1} Loss: {sum(losses) / 10:.3f} "
|
f"Iter: {i + 1} Loss: {total_loss:.3f} "
|
||||||
f"It/s: {10 / (toc - tic):.3f} "
|
f"It/s: {batch_cnt / (toc - tic):.3f} "
|
||||||
f"Peak mem: {peak_mem:.3f} GB",
|
f"Peak mem: {peak_mem:.3f} GB",
|
||||||
flush=True,
|
flush=True,
|
||||||
)
|
)
|
||||||
@ -285,7 +289,8 @@ if __name__ == "__main__":
|
|||||||
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
save_adapters(f"{i + 1:07d}_adapters.safetensors", flux, args)
|
||||||
|
|
||||||
if (i + 1) % 10 == 0:
|
if (i + 1) % 10 == 0:
|
||||||
losses = []
|
total_loss = 0
|
||||||
|
batch_cnt = 0
|
||||||
tic = time.time()
|
tic = time.time()
|
||||||
|
|
||||||
save_adapters("final_adapters.safetensors", flux, args)
|
save_adapters("final_adapters.safetensors", flux, args)
|
||||||
|
Loading…
Reference in New Issue
Block a user