Add better memory usage reporting

This commit is contained in:
Angelos Katharopoulos 2024-10-11 20:52:12 -07:00
parent 7c8c5818f7
commit 4bd0294598

View File

@@ -94,6 +94,7 @@ if __name__ == "__main__":
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# The following is not necessary, but it may help on memory-constrained
# systems by reusing the memory kept by the text encoders.
@@ -108,13 +109,17 @@ if __name__ == "__main__":
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
if args.save_raw:
*name, suffix = args.output.split(".")
@@ -139,6 +144,7 @@ if __name__ == "__main__":
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the text: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")