Fix data parallel generation

This commit is contained in:
Angelos Katharopoulos 2025-03-22 16:43:40 -07:00
parent 208856520d
commit a1e259607e

View File

@ -126,7 +126,7 @@ if __name__ == "__main__":
# Decode them into images # Decode them into images
decoded = [] decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): for i in tqdm(range(0, n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1]) mx.eval(decoded[-1])
peak_mem_decoding = mx.get_peak_memory() / 1024**3 peak_mem_decoding = mx.get_peak_memory() / 1024**3
@ -162,7 +162,7 @@ if __name__ == "__main__":
im.save(args.output) im.save(args.output)
# Report the peak memory used during generation # Report the peak memory used during generation
if args.verbose: if args.verbose and group.rank() == 0:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB") print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")