diff --git a/flux/flux/layers.py b/flux/flux/layers.py index 12397904..045f1e38 100644 --- a/flux/flux/layers.py +++ b/flux/flux/layers.py @@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module): nn.Linear(mlp_hidden_dim, hidden_size, bias=True), ) + self.sharding_group = None + def __call__( self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array ) -> Tuple[mx.array, mx.array]: @@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module): attn = _attention(q, k, v, pe) txt_attn, img_attn = mx.split(attn, [S], axis=1) + # Project - cat - average - split + txt_attn = self.txt_attn.proj(txt_attn) + img_attn = self.img_attn.proj(img_attn) + if self.sharding_group is not None: + attn = mx.concatenate([txt_attn, img_attn], axis=1) + attn = mx.distributed.all_sum(attn, group=self.sharding_group) + txt_attn, img_attn = mx.split(attn, [S], axis=1) + # calculate the img bloks - img = img + img_mod1.gate * self.img_attn.proj(img_attn) - img = img + img_mod2.gate * self.img_mlp( + img = img + img_mod1.gate * img_attn + img_mlp = self.img_mlp( (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift ) # calculate the txt bloks - txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) - txt = txt + txt_mod2.gate * self.txt_mlp( + txt = txt + txt_mod1.gate * txt_attn + txt_mlp = self.txt_mlp( (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift ) + if self.sharding_group is not None: + txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1) + txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group) + txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1) + + # finalize the img/txt blocks + img = img + img_mod2.gate * img_mlp + txt = txt + txt_mod2.gate * txt_mlp + return img, txt diff --git a/flux/flux/model.py b/flux/flux/model.py index df0903d3..bd74a393 100644 --- a/flux/flux/model.py +++ b/flux/flux/model.py @@ -5,7 +5,7 @@ from typing import Optional import mlx.core as mx import mlx.nn as nn -from mlx.nn.layers.distributed import shard_linear +from mlx.nn.layers.distributed import shard_inplace, shard_linear from .layers import ( DoubleStreamBlock, @@ -107,30 +107,23 @@ class Flux(nn.Module): block.num_heads //= N block.img_attn.num_heads //= N block.txt_attn.num_heads //= N + block.sharding_group = group block.img_attn.qkv = shard_linear( block.img_attn.qkv, "all-to-sharded", groups=3, group=group ) block.txt_attn.qkv = shard_linear( block.txt_attn.qkv, "all-to-sharded", groups=3, group=group ) - block.img_attn.proj = shard_linear( - block.img_attn.proj, "sharded-to-all", group=group - ) - block.txt_attn.proj = shard_linear( - block.txt_attn.proj, "sharded-to-all", group=group - ) + shard_inplace(block.img_attn.proj, "sharded-to-all", group=group) + shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group) block.img_mlp.layers[0] = shard_linear( block.img_mlp.layers[0], "all-to-sharded", group=group ) block.txt_mlp.layers[0] = shard_linear( block.txt_mlp.layers[0], "all-to-sharded", group=group ) - block.img_mlp.layers[2] = shard_linear( - block.img_mlp.layers[2], "sharded-to-all", group=group - ) - block.txt_mlp.layers[2] = shard_linear( - block.txt_mlp.layers[2], "sharded-to-all", group=group - ) + shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group) + shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group) for block in self.single_blocks: block.num_heads //= N diff --git a/flux/generate_interactive.py b/flux/generate_interactive.py new file mode 100644 index 00000000..448dc5c3 --- /dev/null +++ b/flux/generate_interactive.py @@ -0,0 +1,109 @@ +import argparse + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from PIL import Image +from tqdm import tqdm + +from flux import FluxPipeline + + +def print_zero(group, *args, **kwargs): + if group.rank() == 0: + flush = kwargs.pop("flush", True) + print(*args, **kwargs, flush=flush) + + +def quantization_predicate(name, m): + return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0 + + +def to_latent_size(image_size): + h, w = image_size + h = ((h + 15) // 16) * 16 + w = ((w + 15) // 16) * 16 + + if (h, w) != image_size: + print( + "Warning: The image dimensions need to be divisible by 16px. " + f"Changing size to {h}x{w}." + ) + + return (h // 8, w // 8) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Generate images from a textual prompt using stable diffusion" + ) + parser.add_argument("--quantize", "-q", action="store_true") + parser.add_argument("--model", choices=["schnell", "dev"], default="schnell") + parser.add_argument("--output", default="out.png") + args = parser.parse_args() + + flux = FluxPipeline("flux-" + args.model, t5_padding=True) + + if args.quantize: + nn.quantize(flux.flow, class_predicate=quantization_predicate) + nn.quantize(flux.t5, class_predicate=quantization_predicate) + nn.quantize(flux.clip, class_predicate=quantization_predicate) + + group = mx.distributed.init() + if group.size() > 1: + flux.flow.shard(group) + + print_zero(group, "Loading models") + flux.ensure_models_are_loaded() + + def print_help(): + print_zero(group, "The command list:") + print_zero(group, "- 'q' to exit") + print_zero(group, "- 's HxW' to change the size of the image") + print_zero(group, "- 'n S' to change the number of steps") + print_zero(group, "- 'h' to print this help") + + print_zero(group, "FLUX interactive session") + print_help() + seed = 0 + size = (512, 512) + latent_size = to_latent_size(size) + steps = 50 if args.model == "dev" else 4 + while True: + prompt = input(">> " if group.rank() == 0 else "") + if prompt == "q": + break + if prompt == "h": + print_help() + continue + if prompt.startswith("s "): + size = tuple([int(xi) for xi in prompt[2:].split("x")]) + print_zero(group, "Setting the size to", size) + latent_size = to_latent_size(size) + continue + if prompt.startswith("n "): + steps = int(prompt[2:]) + print_zero(group, "Setting the steps to", steps) + continue + + seed += 1 + latents = flux.generate_latents( + prompt, + n_images=1, + num_steps=steps, + latent_size=latent_size, + guidance=4.0, + seed=seed, + ) + print_zero(group, "Processing prompt") + mx.eval(next(latents)) + print_zero(group, "Generating latents") + for xt in tqdm(latents, total=steps, disable=group.rank() > 0): + mx.eval(xt) + print_zero(group, "Generating image") + xt = flux.decode(xt, latent_size) + xt = (xt * 255).astype(mx.uint8) + mx.eval(xt) + im = Image.fromarray(np.array(xt[0])) + im.save(args.output) + print_zero(group, "Saved at", args.output, end="\n\n")