Distributed FLUX (#1325)

This commit is contained in:
Angelos Katharopoulos 2025-03-24 22:16:48 -07:00 committed by GitHub
parent c243370044
commit c52cc748f8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 282 additions and 18 deletions

View File

@ -167,8 +167,9 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6
```
Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
Or you can directly use the pre-processed Hugging Face dataset
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
for fine-tuning.
```shell
python dreambooth.py \
@ -210,3 +211,71 @@ speed during generation.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .
Distributed Computation
------------------------
The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.
### Distributed Finetuning
Distributed finetuning scales very well with FLUX and all one has to do is
adjust the gradient accumulation and training iterations so that the batch
size remains the same. For instance, to replicate the following training
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
mlx-community/dreambooth-dog6
```
on 4 machines we simply run
```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 150 --iterations 300 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 2 \
mlx-community/dreambooth-dog6
```
Note that the iterations changed from 1200 to 300 and the gradient accumulation steps from 8 to 2.
### Distributed Inference
Distributed inference can be divided into two different approaches. The first
approach is the data-parallel approach, where each node generates its own
images and shares them at the end. The second approach is the model-parallel
approach where the model is shared across the nodes and they collaboratively
generate the images.
The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.
For better performance in the model-parallel approach we suggest that you use a
[thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
All you have to do once again is use `mlx.launch` as follows
```shell
mlx.launch --verbose --hostfile hostfile.json -- \
python txt2image.py --model schnell \
--n-images 8 \
--image-size 512x512 \
--verbose \
'A photo of an astronaut riding a horse on Mars'
```
For model-parallel generation you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`, which is an experimental setting that
reduces the CPU/GPU synchronization overhead.

View File

@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module):
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.sharding_group = None
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module):
attn = _attention(q, k, v, pe)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# Project - cat - average - split
txt_attn = self.txt_attn.proj(txt_attn)
img_attn = self.img_attn.proj(img_attn)
if self.sharding_group is not None:
attn = mx.concatenate([txt_attn, img_attn], axis=1)
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img blocks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
img = img + img_mod1.gate * img_attn
img_mlp = self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt blocks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
txt = txt + txt_mod1.gate * txt_attn
txt_mlp = self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
if self.sharding_group is not None:
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
# finalize the img/txt blocks
img = img + img_mod2.gate * img_mlp
txt = txt + txt_mod2.gate * txt_mlp
return img, txt

View File

@ -5,6 +5,7 @@ from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.distributed import shard_inplace, shard_linear
from .layers import (
DoubleStreamBlock,
@ -96,6 +97,47 @@ class Flux(nn.Module):
new_weights[k] = w
return new_weights
def shard(self, group: Optional[mx.distributed.Group] = None):
    """Shard the double and single stream blocks across ``group``.

    Args:
        group: The distributed group to shard over. When omitted (or falsy)
            the group returned by ``mx.distributed.init()`` is used.

    The method mutates the model in place and returns ``None``. When the
    group has a single member nothing is changed.
    """
    group = group or mx.distributed.init()
    world_size = group.size()
    if world_size == 1:
        return

    for block in self.double_blocks:
        # Each member keeps 1/world_size of the heads and needs the group so
        # that the block can all_sum its partial results in __call__.
        block.num_heads //= world_size
        block.sharding_group = group

        # The image and text streams are sharded identically.
        for attn in (block.img_attn, block.txt_attn):
            attn.num_heads //= world_size
            attn.qkv = shard_linear(
                attn.qkv, "all-to-sharded", segments=3, group=group
            )
            shard_inplace(attn.proj, "sharded-to-all", group=group)

        for mlp in (block.img_mlp, block.txt_mlp):
            mlp.layers[0] = shard_linear(
                mlp.layers[0], "all-to-sharded", group=group
            )
            shard_inplace(mlp.layers[2], "sharded-to-all", group=group)

    for block in self.single_blocks:
        block.num_heads //= world_size
        block.hidden_size //= world_size
        # linear1 and linear2 are fused layers; the fractional segments tell
        # shard_linear where to split them before sharding each part.
        block.linear1 = shard_linear(
            block.linear1,
            "all-to-sharded",
            segments=[1 / 7, 2 / 7, 3 / 7],
            group=group,
        )
        block.linear2 = shard_linear(
            block.linear2, "sharded-to-all", segments=[1 / 5], group=group
        )
def __call__(
self,
img: mx.array,

View File

@ -0,0 +1,109 @@
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def print_zero(group, *args, **kwargs):
    """Print only on the process whose rank in ``group`` is 0.

    Positional and keyword arguments are forwarded to the builtin ``print``.
    Output is flushed unless the caller passes ``flush`` explicitly.
    """
    if group.rank() != 0:
        return
    # kwargs is a fresh dict for every call, so filling in the default is safe.
    kwargs.setdefault("flush", True)
    print(*args, **kwargs)
def quantization_predicate(name, m):
    """Return whether module ``m`` should be quantized.

    A module qualifies when it exposes ``to_quantized`` and the second
    dimension of its weight is a multiple of 512. ``name`` is accepted so the
    function matches the predicate signature but is not used.
    """
    if not hasattr(m, "to_quantized"):
        return False
    return m.weight.shape[1] % 512 == 0
def to_latent_size(image_size):
    """Convert an image size in pixels to the matching latent size.

    Args:
        image_size: A ``(height, width)`` pair in pixels. Any two-element
            sequence is accepted.

    Returns:
        A ``(height // 8, width // 8)`` tuple computed after rounding both
        dimensions up to the next multiple of 16.

    A warning is printed when the requested size had to be rounded.
    """
    # Normalise to a tuple so that the comparison below does not report a
    # change for an already valid size passed as a list.
    requested = tuple(image_size)
    h, w = requested
    h = ((h + 15) // 16) * 16
    w = ((w + 15) // 16) * 16

    if (h, w) != requested:
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )

    return (h // 8, w // 8)
if __name__ == "__main__":
    # Interactive FLUX session: read prompts or commands from stdin in a loop
    # and write the generated image to --output after every prompt.
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--output", default="out.png")
    args = parser.parse_args()

    flux = FluxPipeline("flux-" + args.model, t5_padding=True)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    # Shard the flow model whenever more than one process participates.
    group = mx.distributed.init()
    if group.size() > 1:
        flux.flow.shard(group)

    print_zero(group, "Loading models")
    flux.ensure_models_are_loaded()

    def print_help():
        """Print the list of interactive commands on rank 0."""
        print_zero(group, "The command list:")
        print_zero(group, "- 'q' to exit")
        print_zero(group, "- 's HxW' to change the size of the image")
        print_zero(group, "- 'n S' to change the number of steps")
        print_zero(group, "- 'h' to print this help")

    print_zero(group, "FLUX interactive session")
    print_help()

    seed = 0
    size = (512, 512)
    latent_size = to_latent_size(size)
    steps = 50 if args.model == "dev" else 4

    while True:
        # Only rank 0 shows the prompt marker; every rank still reads a line.
        prompt = input(">> " if group.rank() == 0 else "")

        if prompt == "q":
            break

        if prompt == "h":
            print_help()
            continue

        if prompt.startswith("s "):
            # A malformed size must not take down the whole session, so
            # report it and keep the previous setting.
            try:
                height, width = (int(xi) for xi in prompt[2:].split("x"))
                if height < 1 or width < 1:
                    raise ValueError("size must be positive")
            except ValueError:
                print_zero(
                    group, "Invalid size, expected 'HxW' (for example 's 512x512')"
                )
                continue
            size = (height, width)
            print_zero(group, "Setting the size to", size)
            latent_size = to_latent_size(size)
            continue

        if prompt.startswith("n "):
            # At least one step is needed, otherwise the denoising loop below
            # never assigns xt.
            try:
                new_steps = int(prompt[2:])
                if new_steps < 1:
                    raise ValueError("steps must be positive")
            except ValueError:
                print_zero(
                    group, "Invalid number of steps, expected a positive integer"
                )
                continue
            steps = new_steps
            print_zero(group, "Setting the steps to", steps)
            continue

        # Anything else is a text prompt; use a new seed for every image.
        seed += 1
        latents = flux.generate_latents(
            prompt,
            n_images=1,
            num_steps=steps,
            latent_size=latent_size,
            guidance=4.0,
            seed=seed,
        )

        # The first item of the generator is evaluated separately, before the
        # denoising steps.
        print_zero(group, "Processing prompt")
        mx.eval(next(latents))

        print_zero(group, "Generating latents")
        for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
            mx.eval(xt)

        print_zero(group, "Generating image")
        xt = flux.decode(xt, latent_size)
        xt = (xt * 255).astype(mx.uint8)
        mx.eval(xt)
        im = Image.fromarray(np.array(xt[0]))
        im.save(args.output)
        print_zero(group, "Saved at", args.output, end="\n\n")

View File

@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
description="Generate images from a textual prompt using FLUX"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
@ -62,6 +62,7 @@ if __name__ == "__main__":
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
parser.add_argument("--force-shard", action="store_true")
args = parser.parse_args()
# Load the models
@ -76,6 +77,24 @@ if __name__ == "__main__":
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
# Figure out what kind of distributed generation we should do
group = mx.distributed.init()
n_images = args.n_images
should_gather = False
if group.size() > 1:
if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
flux.flow.shard(group)
else:
n_images //= group.size()
should_gather = True
# If we are sharding we should have the same seed and if we are doing
# data parallel generation we should have different seeds
if args.seed is None:
args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
if should_gather:
args.seed = args.seed + group.rank()
if args.preload_models:
flux.ensure_models_are_loaded()
@ -83,7 +102,7 @@ if __name__ == "__main__":
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=args.n_images,
n_images=n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
@ -93,8 +112,8 @@ if __name__ == "__main__":
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
peak_mem_conditioning = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
@ -102,36 +121,42 @@ if __name__ == "__main__":
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps):
for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
peak_mem_generation = mx.get_peak_memory() / 1024**3
mx.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
for i in tqdm(range(0, n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_decoding = mx.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
# Gather them if each node has different images
decoded = mx.concatenate(decoded, axis=0)
if should_gather:
decoded = mx.distributed.all_gather(decoded)
mx.eval(decoded)
if args.save_raw:
*name, suffix = args.output.split(".")
name = ".".join(name)
x = mx.concatenate(decoded, axis=0)
x = decoded
x = (x * 255).astype(mx.uint8)
for i in range(len(x)):
im = Image.fromarray(np.array(x[i]))
im.save(".".join([name, str(i), suffix]))
else:
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = decoded
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@ -143,7 +168,7 @@ if __name__ == "__main__":
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
if args.verbose and group.rank() == 0:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")