From a1e259607e8c7f4f1fd27effb145e5cf3be145f1 Mon Sep 17 00:00:00 2001 From: Angelos Katharopoulos Date: Sat, 22 Mar 2025 16:43:40 -0700 Subject: [PATCH] Fix data parallel generation --- flux/txt2image.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flux/txt2image.py b/flux/txt2image.py index 98cd8633..2d3857e2 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -126,7 +126,7 @@ if __name__ == "__main__": # Decode them into images decoded = [] - for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): + for i in tqdm(range(0, n_images, args.decoding_batch_size)): decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) mx.eval(decoded[-1]) peak_mem_decoding = mx.get_peak_memory() / 1024**3 @@ -162,7 +162,7 @@ if __name__ == "__main__": im.save(args.output) # Report the peak memory used during generation - if args.verbose: + if args.verbose and group.rank() == 0: print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB") print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")