Distributed FLUX (#1325)

Author: Angelos Katharopoulos
Date: 2025-03-24 22:16:48 -07:00
Committed by: GitHub
Commit: c52cc748f8 (parent c243370044)
5 changed files with 282 additions and 18 deletions


@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Generate images from a textual prompt using stable diffusion"
+        description="Generate images from a textual prompt using FLUX"
     )
     parser.add_argument("prompt")
     parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
@@ -62,6 +62,7 @@ if __name__ == "__main__":
     parser.add_argument("--adapter")
     parser.add_argument("--fuse-adapter", action="store_true")
     parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
+    parser.add_argument("--force-shard", action="store_true")
     args = parser.parse_args()

     # Load the models
@@ -76,6 +77,24 @@ if __name__ == "__main__":
         nn.quantize(flux.t5, class_predicate=quantization_predicate)
         nn.quantize(flux.clip, class_predicate=quantization_predicate)

+    # Figure out what kind of distributed generation we should do
+    group = mx.distributed.init()
+    n_images = args.n_images
+    should_gather = False
+    if group.size() > 1:
+        if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
+            flux.flow.shard(group)
+        else:
+            n_images //= group.size()
+            should_gather = True
+
+    # If we are sharding we should have the same seed and if we are doing
+    # data parallel generation we should have different seeds
+    if args.seed is None:
+        args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
+    if should_gather:
+        args.seed = args.seed + group.rank()
+
     if args.preload_models:
         flux.ensure_models_are_loaded()
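This hunk chooses between two strategies: when the requested batch divides evenly across the ranks (and `--force-shard` is not given), each rank generates its own slice of the batch data-parallel and the results are gathered at the end; otherwise the flow transformer itself is sharded so all ranks cooperate on the same images. Below is a minimal standalone sketch of that decision and the matching seed handling, with `n_images` and `force_shard` as illustrative stand-ins for the parsed CLI arguments:

```python
import mlx.core as mx

group = mx.distributed.init()  # size-1 group when run without a launcher
n_images, force_shard = 8, False  # stand-ins for args.n_images / args.force_shard
should_gather = False

if group.size() > 1:
    if force_shard or n_images < group.size() or n_images % group.size() != 0:
        # Model parallel: every rank works on the same images, so the
        # flow transformer would be sharded here (flux.flow.shard(group)).
        pass
    else:
        # Data parallel: each rank generates its own slice of the batch.
        n_images //= group.size()
        should_gather = True

# all_sum of a per-rank random draw yields a seed shared by all ranks;
# data-parallel ranks then offset it so each generates different images.
seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
if should_gather:
    seed = seed + group.rank()
print(f"rank {group.rank()}: seed={seed}, n_images={n_images}")
```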
@@ -83,7 +102,7 @@ if __name__ == "__main__":
     latent_size = to_latent_size(args.image_size)
     latents = flux.generate_latents(
         args.prompt,
-        n_images=args.n_images,
+        n_images=n_images,
         num_steps=args.steps,
         latent_size=latent_size,
         guidance=args.guidance,
@@ -93,8 +112,8 @@ if __name__ == "__main__":
     # First we get and eval the conditioning
     conditioning = next(latents)
     mx.eval(conditioning)
-    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()

     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the text encoders.
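Aside from the distributed changes, this hunk also picks up MLX's move of the memory introspection helpers from `mx.metal.*` to the top level of `mlx.core`. A quick sketch of the measurement pattern used throughout the script, assuming a recent MLX where the top-level functions are available:

```python
import mlx.core as mx

mx.reset_peak_memory()              # start a fresh measurement window
x = mx.random.normal((1024, 1024))
mx.eval(x @ x)                      # force the computation to actually run
print(f"peak: {mx.get_peak_memory() / 1024**3:.3f} GB")
```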
@@ -102,36 +121,42 @@ if __name__ == "__main__":
     del flux.clip

     # Actual denoising loop
-    for x_t in tqdm(latents, total=args.steps):
+    for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
         mx.eval(x_t)

     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the flow transformer.
     del flux.flow
-    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_generation = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()

     # Decode them into images
     decoded = []
-    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+    for i in tqdm(range(0, n_images, args.decoding_batch_size)):
         decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
         mx.eval(decoded[-1])
-    peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
+    peak_mem_decoding = mx.get_peak_memory() / 1024**3
     peak_mem_overall = max(
         peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
     )

+    # Gather them if each node has different images
+    decoded = mx.concatenate(decoded, axis=0)
+    if should_gather:
+        decoded = mx.distributed.all_gather(decoded)
+    mx.eval(decoded)
+
     if args.save_raw:
         *name, suffix = args.output.split(".")
         name = ".".join(name)
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = (x * 255).astype(mx.uint8)
         for i in range(len(x)):
             im = Image.fromarray(np.array(x[i]))
             im.save(".".join([name, str(i), suffix]))
     else:
         # Arrange them on a grid
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
         B, H, W, C = x.shape
         x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
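In the data-parallel case each rank holds only its own slice of the batch after decoding, so the images are gathered before saving. `mx.distributed.all_gather` concatenates the per-rank arrays along the first axis, which is why the previous `mx.concatenate` calls in both save paths collapse to `x = decoded`. A small sketch of the semantics (this degenerates to a plain copy on a size-1 group):

```python
import mlx.core as mx

group = mx.distributed.init()
local = mx.full((2, 4), group.rank())        # each rank's "images": shape (2, 4)
gathered = mx.distributed.all_gather(local)  # shape (2 * group.size(), 4)
mx.eval(gathered)
print(group.rank(), gathered.shape)
```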
@@ -143,7 +168,7 @@ if __name__ == "__main__":
     im.save(args.output)

     # Report the peak memory used during generation
-    if args.verbose:
+    if args.verbose and group.rank() == 0:
         print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
         print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
         print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
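Finally, output is gated on rank 0 so reports print once rather than once per process; the earlier `tqdm(..., disable=group.rank() > 0)` change follows the same pattern. In isolation:

```python
import mlx.core as mx
from tqdm import tqdm

group = mx.distributed.init()
for _ in tqdm(range(10), disable=group.rank() > 0):
    pass  # only rank 0 renders the progress bar
if group.rank() == 0:
    print("done")
```

To exercise the multi-rank paths, the script would be started under MLX's distributed launcher, e.g. something like `mlx.launch --hosts host1,host2 <script> "a prompt" ...`; the exact invocation depends on the configured backend, so treat that command as indicative rather than exact.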