Distributed FLUX (#1325)
This commit is contained in:
parent c243370044
commit c52cc748f8

@@ -167,8 +167,9 @@ python dreambooth.py \
  path/to/dreambooth/dataset/dog6
```

Or you can directly use the pre-processed Hugging Face dataset
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
for fine-tuning.

```shell
python dreambooth.py \

@@ -210,3 +211,71 @@ speed during generation.

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from Unsplash, by https://unsplash.com/@alvannee.

Distributed Computation
------------------------

The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.

### Distributed Finetuning

Distributed finetuning scales very well with FLUX; all one has to do is adjust
the gradient accumulation steps and training iterations so that the effective
batch size remains the same. For instance, to replicate the following
single-node training run

```shell
python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 8 \
  mlx-community/dreambooth-dog6
```

on 4 machines, we simply run

```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 150 --iterations 300 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 2 \
  mlx-community/dreambooth-dog6
```

Note that the iterations changed from 1200 to 300 and the gradient accumulation
steps from 8 to 2.
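
The arithmetic is simply keeping the number of optimizer updates and the
effective batch size per update constant as nodes are added. A purely
illustrative sketch of that bookkeeping (the helper below is not part of the
example scripts):

```python
# Scale single-node hyperparameters to N nodes so that both the effective
# batch size per update and the total number of examples seen stay fixed.
def scale_to_nodes(iterations, grad_accumulate, num_nodes):
    assert iterations % num_nodes == 0
    assert grad_accumulate % num_nodes == 0
    return {
        "iterations": iterations // num_nodes,
        "grad_accumulate": grad_accumulate // num_nodes,
    }

# The single-node run above: 1200 iterations with 8-step gradient accumulation.
print(scale_to_nodes(1200, 8, 4))  # {'iterations': 300, 'grad_accumulate': 2}
```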

### Distributed Inference

Distributed inference can be divided into two different approaches. The first
approach is the data-parallel approach, where each node generates its own
images and shares them at the end. The second approach is the model-parallel
approach, where the model is shared across the nodes and they collaboratively
generate the images.

The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.
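
The selection logic is small; the sketch below mirrors what the updated
`txt2image.py` does (see its diff further down): shard the flow transformer
unless the requested images divide evenly across the nodes, in which case each
node generates its own share.

```python
import argparse

import mlx.core as mx

# Stub parser just for this sketch; txt2image.py defines these flags itself.
parser = argparse.ArgumentParser()
parser.add_argument("--n-images", type=int, default=8)
parser.add_argument("--force-shard", action="store_true")
args = parser.parse_args()

group = mx.distributed.init()
n_images = args.n_images
should_gather = False
if group.size() > 1:
    if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
        # Model-parallel: all nodes cooperate on the same images.
        # (Here txt2image.py calls flux.flow.shard(group).)
        pass
    else:
        # Data-parallel: each node generates its own slice of the batch and
        # the images are all-gathered at the end.
        n_images //= group.size()
        should_gather = True
print(group.rank(), n_images, should_gather)
```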

For better performance in the model-parallel approach we suggest that you use a
[Thunderbolt ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).

All you have to do, once again, is use `mlx.launch` as follows:

```shell
mlx.launch --verbose --hostfile hostfile.json -- \
    python txt2image.py --model schnell \
    --n-images 8 \
    --image-size 512x512 \
    --verbose \
    'A photo of an astronaut riding a horse on Mars'
```

For model-parallel generation you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`, which is an experimental setting that
reduces the CPU/GPU synchronization overhead.
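
For example, the model-parallel invocation above becomes (the placement of the
`--env` option among the other `mlx.launch` flags is illustrative):

```shell
mlx.launch --verbose --env MLX_METAL_FAST_SYNCH=1 --hostfile hostfile.json -- \
    python txt2image.py --model schnell --n-images 8 --image-size 512x512 --verbose \
    'A photo of an astronaut riding a horse on Mars'
```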

@@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module):
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
 
+        self.sharding_group = None
+
     def __call__(
         self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
     ) -> Tuple[mx.array, mx.array]:
@@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module):
         attn = _attention(q, k, v, pe)
         txt_attn, img_attn = mx.split(attn, [S], axis=1)
 
+        # Project - cat - average - split
+        txt_attn = self.txt_attn.proj(txt_attn)
+        img_attn = self.img_attn.proj(img_attn)
+        if self.sharding_group is not None:
+            attn = mx.concatenate([txt_attn, img_attn], axis=1)
+            attn = mx.distributed.all_sum(attn, group=self.sharding_group)
+            txt_attn, img_attn = mx.split(attn, [S], axis=1)
+
         # calculate the img bloks
-        img = img + img_mod1.gate * self.img_attn.proj(img_attn)
-        img = img + img_mod2.gate * self.img_mlp(
+        img = img + img_mod1.gate * img_attn
+        img_mlp = self.img_mlp(
             (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
         )
 
         # calculate the txt bloks
-        txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
-        txt = txt + txt_mod2.gate * self.txt_mlp(
+        txt = txt + txt_mod1.gate * txt_attn
+        txt_mlp = self.txt_mlp(
             (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
         )
 
+        if self.sharding_group is not None:
+            txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
+            txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
+            txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
+
+        # finalize the img/txt blocks
+        img = img + img_mod2.gate * img_mlp
+        txt = txt + txt_mod2.gate * txt_mlp
+
         return img, txt
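
The change above is the standard tensor-parallel pattern: each rank keeps 1/N
of the attention heads and MLP hidden units, computes its partial projection
locally, and a single `mx.distributed.all_sum` per block combines the partial
outputs. A minimal standalone sketch of that sharded-matmul-then-all-sum
pattern (sizes and names here are illustrative, not the actual FLUX
dimensions):

```python
import mlx.core as mx

group = mx.distributed.init()
N, r = group.size(), group.rank()

# Same seed everywhere so every rank builds identical toy weights/activations.
mx.random.seed(0)
x = mx.random.normal((1, 16, 1024))      # full activations
w_full = mx.random.normal((1024, 1024))  # conceptually the full projection weight

# Each rank keeps one slice of the weight's input dimension and the matching
# slice of the activations, so its matmul produces only a partial output.
w_local = mx.split(w_full, N, axis=0)[r]
x_local = mx.split(x, N, axis=-1)[r]
y_partial = x_local @ w_local

# Summing the partial outputs across ranks recovers the full x @ w_full.
y = mx.distributed.all_sum(y_partial, group=group)
mx.eval(y)
```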

@@ -5,6 +5,7 @@ from typing import Optional
 
 import mlx.core as mx
 import mlx.nn as nn
+from mlx.nn.layers.distributed import shard_inplace, shard_linear
 
 from .layers import (
     DoubleStreamBlock,
@@ -96,6 +97,47 @@ class Flux(nn.Module):
             new_weights[k] = w
         return new_weights
 
+    def shard(self, group: Optional[mx.distributed.Group] = None):
+        group = group or mx.distributed.init()
+        N = group.size()
+        if N == 1:
+            return
+
+        for block in self.double_blocks:
+            block.num_heads //= N
+            block.img_attn.num_heads //= N
+            block.txt_attn.num_heads //= N
+            block.sharding_group = group
+            block.img_attn.qkv = shard_linear(
+                block.img_attn.qkv, "all-to-sharded", segments=3, group=group
+            )
+            block.txt_attn.qkv = shard_linear(
+                block.txt_attn.qkv, "all-to-sharded", segments=3, group=group
+            )
+            shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
+            shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
+            block.img_mlp.layers[0] = shard_linear(
+                block.img_mlp.layers[0], "all-to-sharded", group=group
+            )
+            block.txt_mlp.layers[0] = shard_linear(
+                block.txt_mlp.layers[0], "all-to-sharded", group=group
+            )
+            shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
+            shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
+
+        for block in self.single_blocks:
+            block.num_heads //= N
+            block.hidden_size //= N
+            block.linear1 = shard_linear(
+                block.linear1,
+                "all-to-sharded",
+                segments=[1 / 7, 2 / 7, 3 / 7],
+                group=group,
+            )
+            block.linear2 = shard_linear(
+                block.linear2, "sharded-to-all", segments=[1 / 5], group=group
+            )
+
     def __call__(
         self,
         img: mx.array,
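
Sharding the flow transformer from a driver script is then a one-liner once a
distributed group exists; a sketch of the call pattern used by the scripts
below (the model name shown is the `schnell` default):

```python
import mlx.core as mx

from flux import FluxPipeline

flux = FluxPipeline("flux-schnell", t5_padding=True)

# When launched with mlx.launch this shards the flow transformer across all
# ranks; in a single process group.size() == 1 and shard() is a no-op.
group = mx.distributed.init()
if group.size() > 1:
    flux.flow.shard(group)
```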

flux/generate_interactive.py (new file, 109 lines)
@@ -0,0 +1,109 @@
import argparse

import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm

from flux import FluxPipeline


def print_zero(group, *args, **kwargs):
    if group.rank() == 0:
        flush = kwargs.pop("flush", True)
        print(*args, **kwargs, flush=flush)


def quantization_predicate(name, m):
    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0


def to_latent_size(image_size):
    h, w = image_size
    h = ((h + 15) // 16) * 16
    w = ((w + 15) // 16) * 16

    if (h, w) != image_size:
        print(
            "Warning: The image dimensions need to be divisible by 16px. "
            f"Changing size to {h}x{w}."
        )

    return (h // 8, w // 8)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("--quantize", "-q", action="store_true")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
    parser.add_argument("--output", default="out.png")
    args = parser.parse_args()

    flux = FluxPipeline("flux-" + args.model, t5_padding=True)

    if args.quantize:
        nn.quantize(flux.flow, class_predicate=quantization_predicate)
        nn.quantize(flux.t5, class_predicate=quantization_predicate)
        nn.quantize(flux.clip, class_predicate=quantization_predicate)

    group = mx.distributed.init()
    if group.size() > 1:
        flux.flow.shard(group)

    print_zero(group, "Loading models")
    flux.ensure_models_are_loaded()

    def print_help():
        print_zero(group, "The command list:")
        print_zero(group, "- 'q' to exit")
        print_zero(group, "- 's HxW' to change the size of the image")
        print_zero(group, "- 'n S' to change the number of steps")
        print_zero(group, "- 'h' to print this help")

    print_zero(group, "FLUX interactive session")
    print_help()
    seed = 0
    size = (512, 512)
    latent_size = to_latent_size(size)
    steps = 50 if args.model == "dev" else 4
    while True:
        prompt = input(">> " if group.rank() == 0 else "")
        if prompt == "q":
            break
        if prompt == "h":
            print_help()
            continue
        if prompt.startswith("s "):
            size = tuple([int(xi) for xi in prompt[2:].split("x")])
            print_zero(group, "Setting the size to", size)
            latent_size = to_latent_size(size)
            continue
        if prompt.startswith("n "):
            steps = int(prompt[2:])
            print_zero(group, "Setting the steps to", steps)
            continue

        seed += 1
        latents = flux.generate_latents(
            prompt,
            n_images=1,
            num_steps=steps,
            latent_size=latent_size,
            guidance=4.0,
            seed=seed,
        )
        print_zero(group, "Processing prompt")
        mx.eval(next(latents))
        print_zero(group, "Generating latents")
        for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
            mx.eval(xt)
        print_zero(group, "Generating image")
        xt = flux.decode(xt, latent_size)
        xt = (xt * 255).astype(mx.uint8)
        mx.eval(xt)
        im = Image.fromarray(np.array(xt[0]))
        im.save(args.output)
        print_zero(group, "Saved at", args.output, end="\n\n")
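
A representative way to launch the interactive script across nodes (these are
simply the script's own flags; it also runs as a normal single-process program
without `mlx.launch`):

```shell
mlx.launch --verbose --hostfile hostfile.json -- \
    python generate_interactive.py --model schnell --quantize
```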

@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Generate images from a textual prompt using stable diffusion"
+        description="Generate images from a textual prompt using FLUX"
     )
     parser.add_argument("prompt")
     parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
@@ -62,6 +62,7 @@ if __name__ == "__main__":
     parser.add_argument("--adapter")
     parser.add_argument("--fuse-adapter", action="store_true")
     parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
+    parser.add_argument("--force-shard", action="store_true")
     args = parser.parse_args()
 
     # Load the models
@@ -76,6 +77,24 @@ if __name__ == "__main__":
         nn.quantize(flux.t5, class_predicate=quantization_predicate)
         nn.quantize(flux.clip, class_predicate=quantization_predicate)
 
+    # Figure out what kind of distributed generation we should do
+    group = mx.distributed.init()
+    n_images = args.n_images
+    should_gather = False
+    if group.size() > 1:
+        if args.force_shard or n_images < group.size() or n_images % group.size() != 0:
+            flux.flow.shard(group)
+        else:
+            n_images //= group.size()
+            should_gather = True
+
+    # If we are sharding we should have the same seed and if we are doing
+    # data parallel generation we should have different seeds
+    if args.seed is None:
+        args.seed = mx.distributed.all_sum(mx.random.randint(0, 2**20)).item()
+    if should_gather:
+        args.seed = args.seed + group.rank()
+
     if args.preload_models:
         flux.ensure_models_are_loaded()
 
@@ -83,7 +102,7 @@ if __name__ == "__main__":
     latent_size = to_latent_size(args.image_size)
     latents = flux.generate_latents(
         args.prompt,
-        n_images=args.n_images,
+        n_images=n_images,
         num_steps=args.steps,
         latent_size=latent_size,
         guidance=args.guidance,
@@ -93,8 +112,8 @@ if __name__ == "__main__":
     # First we get and eval the conditioning
     conditioning = next(latents)
     mx.eval(conditioning)
-    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_conditioning = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()
 
     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the text encoders.
@@ -102,36 +121,42 @@ if __name__ == "__main__":
     del flux.clip
 
     # Actual denoising loop
-    for x_t in tqdm(latents, total=args.steps):
+    for x_t in tqdm(latents, total=args.steps, disable=group.rank() > 0):
         mx.eval(x_t)
 
     # The following is not necessary but it may help in memory constrained
     # systems by reusing the memory kept by the flow transformer.
     del flux.flow
-    peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
-    mx.metal.reset_peak_memory()
+    peak_mem_generation = mx.get_peak_memory() / 1024**3
+    mx.reset_peak_memory()
 
     # Decode them into images
     decoded = []
-    for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
+    for i in tqdm(range(0, n_images, args.decoding_batch_size)):
         decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
         mx.eval(decoded[-1])
-    peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
+    peak_mem_decoding = mx.get_peak_memory() / 1024**3
     peak_mem_overall = max(
         peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
     )
 
+    # Gather them if each node has different images
+    decoded = mx.concatenate(decoded, axis=0)
+    if should_gather:
+        decoded = mx.distributed.all_gather(decoded)
+        mx.eval(decoded)
+
     if args.save_raw:
         *name, suffix = args.output.split(".")
         name = ".".join(name)
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = (x * 255).astype(mx.uint8)
         for i in range(len(x)):
             im = Image.fromarray(np.array(x[i]))
             im.save(".".join([name, str(i), suffix]))
     else:
         # Arrange them on a grid
-        x = mx.concatenate(decoded, axis=0)
+        x = decoded
         x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
         B, H, W, C = x.shape
         x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
@@ -143,7 +168,7 @@ if __name__ == "__main__":
     im.save(args.output)
 
     # Report the peak memory used during generation
-    if args.verbose:
+    if args.verbose and group.rank() == 0:
         print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
         print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
         print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
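
The data-parallel path above relies on `mx.distributed.all_gather`, which
concatenates each rank's array along the first axis; a tiny standalone sketch
of that behavior:

```python
import mlx.core as mx

group = mx.distributed.init()

# Each rank contributes one (1, 2, 2, 3) "image"; all_gather concatenates the
# contributions along axis 0, so every rank ends up with (group.size(), 2, 2, 3).
local = mx.full((1, 2, 2, 3), float(group.rank()))
gathered = mx.distributed.all_gather(local, group=group)
mx.eval(gathered)
print(group.rank(), gathered.shape)
```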