mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 17:31:18 +08:00
Fix data parallel generation
This commit is contained in:
parent
208856520d
commit
a1e259607e
@ -126,7 +126,7 @@ if __name__ == "__main__":
|
|||||||
|
|
||||||
# Decode them into images
|
# Decode them into images
|
||||||
decoded = []
|
decoded = []
|
||||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
|
||||||
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
|
||||||
mx.eval(decoded[-1])
|
mx.eval(decoded[-1])
|
||||||
peak_mem_decoding = mx.get_peak_memory() / 1024**3
|
peak_mem_decoding = mx.get_peak_memory() / 1024**3
|
||||||
@ -162,7 +162,7 @@ if __name__ == "__main__":
|
|||||||
im.save(args.output)
|
im.save(args.output)
|
||||||
|
|
||||||
# Report the peak memory used during generation
|
# Report the peak memory used during generation
|
||||||
if args.verbose:
|
if args.verbose and group.rank() == 0:
|
||||||
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
|
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
|
||||||
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
|
||||||
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
|
||||||
|
Loading…
Reference in New Issue
Block a user