From c52cc748f876b898720720dd0e4c4632c9790ebd Mon Sep 17 00:00:00 2001
From: Angelos Katharopoulos
Date: Mon, 24 Mar 2025 22:16:48 -0700
Subject: [PATCH] Distributed FLUX (#1325)

---
 flux/README.md               |  73 ++++++++++++++++++++++-
 flux/flux/layers.py          |  27 +++++++--
 flux/flux/model.py           |  42 ++++++++++++++
 flux/generate_interactive.py | 109 +++++++++++++++++++++++++++++++++++
 flux/txt2image.py            |  49 ++++++++++++----
 5 files changed, 282 insertions(+), 18 deletions(-)
 create mode 100644 flux/generate_interactive.py

diff --git a/flux/README.md b/flux/README.md
index b00a9621..95f86b49 100644
--- a/flux/README.md
+++ b/flux/README.md
@@ -167,8 +167,9 @@ python dreambooth.py \
   path/to/dreambooth/dataset/dog6
 ```
 
-
-Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
+Or you can directly use the pre-processed Hugging Face dataset
+[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
+for fine-tuning.
 
 ```shell
 python dreambooth.py \
@@ -210,3 +211,71 @@ speed during generation.
 
 [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
 [^2]: The images are from unsplash by https://unsplash.com/@alvannee .
+
+
+Distributed Computation
+------------------------
+
+The FLUX example supports distributed computation during both generation and
+training. See the [distributed communication
+documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
+for information on how to set up MLX for distributed communication. The rest
+of this section assumes you can launch distributed MLX programs using
+`mlx.launch --hostfile hostfile.json`.
+
+### Distributed Finetuning
+
+Distributed finetuning scales very well with FLUX. All you have to do is
+adjust the gradient accumulation steps and the training iterations so that the
+effective batch size remains the same. For instance, to replicate the
+following training run
+
+```shell
+python dreambooth.py \
+  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
+  --lora-rank 4 --grad-accumulate 8 \
+  mlx-community/dreambooth-dog6
+```
+
+on 4 machines we simply run
+
+```shell
+mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
+  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+  --progress-every 150 --iterations 300 --learning-rate 0.0001 \
+  --lora-rank 4 --grad-accumulate 2 \
+  mlx-community/dreambooth-dog6
+```
+
+Note that the iterations changed from 1200 to 300 and the gradient
+accumulation steps from 8 to 2, so the effective batch size is unchanged.
+
+### Distributed Inference
+
+Distributed inference can follow two different approaches. The first is the
+data-parallel approach, where each node generates its own images and shares
+them at the end. The second is the model-parallel approach, where the model is
+sharded across the nodes and the nodes collaboratively generate the images.
+
+The `txt2image.py` script will attempt to choose the best approach depending
+on how many images are being generated across the nodes. The model-parallel
+approach can be forced by passing the argument `--force-shard` (see the
+example below).
+
+For better performance in the model-parallel approach we suggest that you use
+a [thunderbolt
+ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
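+
+For example, to force the model-parallel approach regardless of how many
+images are requested, something like the following should work (the prompt
+and image count are only placeholders; `hostfile.json` is the same hostfile
+used above):
+
+```shell
+mlx.launch --verbose --hostfile hostfile.json -- \
+  python txt2image.py --model schnell \
+    --n-images 4 \
+    --force-shard \
+    'A photo of an astronaut riding a horse on Mars'
+```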
+ +All you have to do once again is use `mlx.launch` as follows + +```shell +mlx.launch --verbose --hostfile hostfile.json -- \ + python txt2image.py --model schnell \ + --n-images 8 \ + --image-size 512x512 \ + --verbose \ + 'A photo of an astronaut riding a horse on Mars' +``` + +for model-parallel generation you may want to also pass `--env +MLX_METAL_FAST_SYNCH=1` to `mlx.launch` which is an experimental setting that +reduces the CPU/GPU synchronization overhead. diff --git a/flux/flux/layers.py b/flux/flux/layers.py index 12397904..045f1e38 100644 --- a/flux/flux/layers.py +++ b/flux/flux/layers.py @@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module): nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) + self.sharding_group = None + def __call__( self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array ) -> Tuple[mx.array, mx.array]: @@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module): attn = _attention(q, k, v, pe) txt_attn, img_attn = mx.split(attn, [S], axis=1) + # Project - cat - average - split + txt_attn = self.txt_attn.proj(txt_attn) + img_attn = self.img_attn.proj(img_attn) + if self.sharding_group is not None: + attn = mx.concatenate([txt_attn, img_attn], axis=1) + attn = mx.distributed.all_sum(attn, group=self.sharding_group) + txt_attn, img_attn = mx.split(attn, [S], axis=1) + # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp( + img = img + img_mod1.gate * img_attn + img_mlp = self.img_mlp( (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift ) # calculate the txt bloks - txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt = txt + txt_mod2.gate * self.txt_mlp( + txt = txt + txt_mod1.gate * txt_attn + txt_mlp = self.txt_mlp( (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift ) + if self.sharding_group is not None: + txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1) + txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group) + txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1) + + # finalize the img/txt blocks + img = img + img_mod2.gate * img_mlp + txt = txt + txt_mod2.gate * txt_mlp + return img, txt diff --git a/flux/flux/model.py b/flux/flux/model.py index d8ad9d9b..c524edf3 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -5,6 +5,7 @@ from typing import Optional import mlx.core as mx import mlx.nn as nn +from mlx.nn.layers.distributed import shard_inplace, shard_linear from .layers import ( DoubleStreamBlock, @@ -96,6 +97,47 @@ class Flux(nn.Module): new_weights[k] = w return new_weights + def shard(self, group: Optional[mx.distributed.Group] = None): + group = group or mx.distributed.init() + N = group.size() + if N == 1: + return + + for block in self.double_blocks: + block.num_heads //= N + block.img_attn.num_heads //= N + block.txt_attn.num_heads //= N + block.sharding_group = group + block.img_attn.qkv = shard_linear( + block.img_attn.qkv, "all-to-sharded", segments=3, group=group + ) + block.txt_attn.qkv = shard_linear( + block.txt_attn.qkv, "all-to-sharded", segments=3, group=group + ) + shard_inplace(block.img_attn.proj, "sharded-to-all", group=group) + shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group) + block.img_mlp.layers[0] = shard_linear( + block.img_mlp.layers[0], "all-to-sharded", group=group + ) + block.txt_mlp.layers[0] = shard_linear( + block.txt_mlp.layers[0], "all-to-sharded", group=group + ) + shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group) + 
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group) + + for block in self.single_blocks: + block.num_heads //= N + block.hidden_size //= N + block.linear1 = shard_linear( + block.linear1, + "all-to-sharded", + segments=[1 / 7, 2 / 7, 3 / 7], + group=group, + ) + block.linear2 = shard_linear( + block.linear2, "sharded-to-all", segments=[1 / 5], group=group + ) + def __call__( self, img: mx.array, diff --git a/flux/generate_interactive.py b/flux/generate_interactive.py new file mode 100644 index 00000000..9acde33c --- /dev/null +++ b/flux/generate_interactive.py @@ -0,0 +1,109 @@ +import argparse + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from PIL import Image +from tqdm import tqdm + +from flux import FluxPipeline + + +def print_zero(group, *args, **kwargs): + if group.rank() == 0: + flush = kwargs.pop("flush", True) + print(*args, **kwargs, flush=flush) + + +def quantization_predicate(name, m): + return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0 + + +def to_latent_size(image_size): + h, w = image_size + h = ((h + 15) // 16) * 16 + w = ((w + 15) // 16) * 16 + + if (h, w) != image_size: + print( + "Warning: The image dimensions need to be divisible by 16px. " + f"Changing size to {h}x{w}." + ) + + return (h // 8, w // 8) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate images from a textual prompt using FLUX" + ) + parser.add_argument("--quantize", "-q", action="store_true") + parser.add_argument("--model", choices=["schnell", "dev"], default="schnell") + parser.add_argument("--output", default="out.png") + args = parser.parse_args() + + flux = FluxPipeline("flux-" + args.model, t5_padding=True) + + if args.quantize: + nn.quantize(flux.flow, class_predicate=quantization_predicate) + nn.quantize(flux.t5, class_predicate=quantization_predicate) + nn.quantize(flux.clip, class_predicate=quantization_predicate) + + group = mx.distributed.init() + if group.size() > 1: + flux.flow.shard(group) + + print_zero(group, "Loading models") + flux.ensure_models_are_loaded() + + def print_help(): + print_zero(group, "The command list:") + print_zero(group, "- 'q' to exit") + print_zero(group, "- 's HxW' to change the size of the image") + print_zero(group, "- 'n S' to change the number of steps") + print_zero(group, "- 'h' to print this help") + + print_zero(group, "FLUX interactive session") + print_help() + seed = 0 + size = (512, 512) + latent_size = to_latent_size(size) + steps = 50 if args.model == "dev" else 4 + while True: + prompt = input(">> " if group.rank() == 0 else "") + if prompt == "q": + break + if prompt == "h": + print_help() + continue + if prompt.startswith("s "): + size = tuple([int(xi) for xi in prompt[2:].split("x")]) + print_zero(group, "Setting the size to", size) + latent_size = to_latent_size(size) + continue + if prompt.startswith("n "): + steps = int(prompt[2:]) + print_zero(group, "Setting the steps to", steps) + continue + + seed += 1 + latents = flux.generate_latents( + prompt, + n_images=1, + num_steps=steps, + latent_size=latent_size, + guidance=4.0, + seed=seed, + ) + print_zero(group, "Processing prompt") + mx.eval(next(latents)) + print_zero(group, "Generating latents") + for xt in tqdm(latents, total=steps, disable=group.rank() > 0): + mx.eval(xt) + print_zero(group, "Generating image") + xt = flux.decode(xt, latent_size) + xt = (xt * 255).astype(mx.uint8) + mx.eval(xt) + im = Image.fromarray(np.array(xt[0])) + im.save(args.output) + print_zero(group, "Saved at", 
args.output, end="\n\n") diff --git a/flux/txt2image.py b/flux/txt2image.py index 5ebec81a..cae0a6d9 100644 --- a/flux/txt2image.py +++ b/flux/txt2image.py @@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False): if __name__ == "__main__": parser = argparse.ArgumentParser( - description="Generate images from a textual prompt using stable diffusion" + description="Generate images from a textual prompt using FLUX" ) parser.add_argument("prompt") parser.add_argument("--model", choices=["schnell", "dev"], default="schnell") @@ -62,6 +62,7 @@ if __name__ == "__main__": parser.add_argument("--adapter") parser.add_argument("--fuse-adapter", action="store_true") parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false") + parser.add_argument("--force-shard", action="store_true") args = parser.parse_args() # Load the models @@ -76,6 +77,24 @@ if __name__ == "__main__": nn.quantize(flux.t5, class_predicate=quantization_predicate) nn.quantize(flux.clip, class_predicate=quantization_predicate) + # Figure out what kind of distributed generation we should do + group = mx.distributed.init() + n_images = args.n_images + should_gather = False + if group.size() > 1: + if args.force_shard or n_images < group.size() or n_images % group.size() != 0: + flux.flow.shard(group) + else: + n_images //= group.size() + should_gather = True + + # If we are sharding we should have the same seed and if we are doing + # data parallel generation we should have different seeds + if args.seed is None: + args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item() + if should_gather: + args.seed = args.seed + group.rank() + if args.preload_models: flux.ensure_models_are_loaded() @@ -83,7 +102,7 @@ if __name__ == "__main__": latent_size = to_latent_size(args.image_size) latents = flux.generate_latents( args.prompt, - n_images=args.n_images, + n_images=n_images, num_steps=args.steps, latent_size=latent_size, guidance=args.guidance, @@ -93,8 +112,8 @@ if __name__ == "__main__": # First we get and eval the conditioning conditioning = next(latents) mx.eval(conditioning) - peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3 - mx.metal.reset_peak_memory() + peak_mem_conditioning = mx.get_peak_memory() / 1024**3 + mx.reset_peak_memory() # The following is not necessary but it may help in memory constrained # systems by reusing the memory kept by the text encoders. @@ -102,36 +121,42 @@ if __name__ == "__main__": del flux.clip # Actual denoising loop - for x_t in tqdm(latents, total=args.steps): + for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0): mx.eval(x_t) # The following is not necessary but it may help in memory constrained # systems by reusing the memory kept by the flow transformer. 
del flux.flow - peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 - mx.metal.reset_peak_memory() + peak_mem_generation = mx.get_peak_memory() / 1024**3 + mx.reset_peak_memory() # Decode them into images decoded = [] - for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): + for i in tqdm(range(0, n_images, args.decoding_batch_size)): decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) mx.eval(decoded[-1]) - peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 + peak_mem_decoding = mx.get_peak_memory() / 1024**3 peak_mem_overall = max( peak_mem_conditioning, peak_mem_generation, peak_mem_decoding ) + # Gather them if each node has different images + decoded = mx.concatenate(decoded, axis=0) + if should_gather: + decoded = mx.distributed.all_gather(decoded) + mx.eval(decoded) + if args.save_raw: *name, suffix = args.output.split(".") name = ".".join(name) - x = mx.concatenate(decoded, axis=0) + x = decoded x = (x * 255).astype(mx.uint8) for i in range(len(x)): im = Image.fromarray(np.array(x[i])) im.save(".".join([name, str(i), suffix])) else: # Arrange them on a grid - x = mx.concatenate(decoded, axis=0) + x = decoded x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) B, H, W, C = x.shape x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4) @@ -143,7 +168,7 @@ if __name__ == "__main__": im.save(args.output) # Report the peak memory used during generation - if args.verbose: + if args.verbose and group.rank() == 0: print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB") print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
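
The patch also adds `flux/generate_interactive.py`, an interactive prompt loop
that shards the flow transformer whenever more than one process is in the
distributed group. A possible way to run it (the flags are taken from its
argument parser; `hostfile.json` is the same hostfile used in the README
examples):

```shell
# Single machine
python generate_interactive.py --model schnell --quantize

# Across the nodes listed in hostfile.json
mlx.launch --verbose --hostfile hostfile.json -- \
  python generate_interactive.py --model schnell --quantize
```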