mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-24 09:21:18 +08:00
Reduce the number of communications
This commit is contained in:
parent
7fbd1619eb
commit
02b007f19c
@ -178,6 +178,8 @@ class DoubleStreamBlock(nn.Module):
|
||||
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
|
||||
)
|
||||
|
||||
self.sharding_group = None
|
||||
|
||||
def __call__(
|
||||
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
|
||||
) -> Tuple[mx.array, mx.array]:
|
||||
@ -216,18 +218,35 @@ class DoubleStreamBlock(nn.Module):
|
||||
attn = _attention(q, k, v, pe)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# Project - cat - average - split
|
||||
txt_attn = self.txt_attn.proj(txt_attn)
|
||||
img_attn = self.img_attn.proj(img_attn)
|
||||
if self.sharding_group is not None:
|
||||
attn = mx.concatenate([txt_attn, img_attn], axis=1)
|
||||
attn = mx.distributed.all_sum(attn, group=self.sharding_group)
|
||||
txt_attn, img_attn = mx.split(attn, [S], axis=1)
|
||||
|
||||
# calculate the img bloks
|
||||
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
|
||||
img = img + img_mod2.gate * self.img_mlp(
|
||||
img = img + img_mod1.gate * img_attn
|
||||
img_mlp = self.img_mlp(
|
||||
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
|
||||
)
|
||||
|
||||
# calculate the txt bloks
|
||||
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
|
||||
txt = txt + txt_mod2.gate * self.txt_mlp(
|
||||
txt = txt + txt_mod1.gate * txt_attn
|
||||
txt_mlp = self.txt_mlp(
|
||||
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
|
||||
)
|
||||
|
||||
if self.sharding_group is not None:
|
||||
txt_img = mx.concatenate([txt_mlp, img_mlp], axis=1)
|
||||
txt_img = mx.distributed.all_sum(txt_img, group=self.sharding_group)
|
||||
txt_mlp, img_mlp = mx.split(txt_img, [S], axis=1)
|
||||
|
||||
# finalize the img/txt blocks
|
||||
img = img + img_mod2.gate * img_mlp
|
||||
txt = txt + txt_mod2.gate * txt_mlp
|
||||
|
||||
return img, txt
|
||||
|
||||
|
||||
|
@ -5,7 +5,7 @@ from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.nn.layers.distributed import shard_linear
|
||||
from mlx.nn.layers.distributed import shard_inplace, shard_linear
|
||||
|
||||
from .layers import (
|
||||
DoubleStreamBlock,
|
||||
@ -107,30 +107,23 @@ class Flux(nn.Module):
|
||||
block.num_heads //= N
|
||||
block.img_attn.num_heads //= N
|
||||
block.txt_attn.num_heads //= N
|
||||
block.sharding_group = group
|
||||
block.img_attn.qkv = shard_linear(
|
||||
block.img_attn.qkv, "all-to-sharded", groups=3, group=group
|
||||
)
|
||||
block.txt_attn.qkv = shard_linear(
|
||||
block.txt_attn.qkv, "all-to-sharded", groups=3, group=group
|
||||
)
|
||||
block.img_attn.proj = shard_linear(
|
||||
block.img_attn.proj, "sharded-to-all", group=group
|
||||
)
|
||||
block.txt_attn.proj = shard_linear(
|
||||
block.txt_attn.proj, "sharded-to-all", group=group
|
||||
)
|
||||
shard_inplace(block.img_attn.proj, "sharded-to-all", group=group)
|
||||
shard_inplace(block.txt_attn.proj, "sharded-to-all", group=group)
|
||||
block.img_mlp.layers[0] = shard_linear(
|
||||
block.img_mlp.layers[0], "all-to-sharded", group=group
|
||||
)
|
||||
block.txt_mlp.layers[0] = shard_linear(
|
||||
block.txt_mlp.layers[0], "all-to-sharded", group=group
|
||||
)
|
||||
block.img_mlp.layers[2] = shard_linear(
|
||||
block.img_mlp.layers[2], "sharded-to-all", group=group
|
||||
)
|
||||
block.txt_mlp.layers[2] = shard_linear(
|
||||
block.txt_mlp.layers[2], "sharded-to-all", group=group
|
||||
)
|
||||
shard_inplace(block.img_mlp.layers[2], "sharded-to-all", group=group)
|
||||
shard_inplace(block.txt_mlp.layers[2], "sharded-to-all", group=group)
|
||||
|
||||
for block in self.single_blocks:
|
||||
block.num_heads //= N
|
||||
|
109
flux/generate_interactive.py
Normal file
109
flux/generate_interactive.py
Normal file
@ -0,0 +1,109 @@
|
||||
import argparse
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
from flux import FluxPipeline
|
||||
|
||||
|
||||
def print_zero(group, *args, **kwargs):
|
||||
if group.rank() == 0:
|
||||
flush = kwargs.pop("flush", True)
|
||||
print(*args, **kwargs, flush=flush)
|
||||
|
||||
|
||||
def quantization_predicate(name, m):
|
||||
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
|
||||
|
||||
|
||||
def to_latent_size(image_size):
|
||||
h, w = image_size
|
||||
h = ((h + 15) // 16) * 16
|
||||
w = ((w + 15) // 16) * 16
|
||||
|
||||
if (h, w) != image_size:
|
||||
print(
|
||||
"Warning: The image dimensions need to be divisible by 16px. "
|
||||
f"Changing size to {h}x{w}."
|
||||
)
|
||||
|
||||
return (h // 8, w // 8)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate images from a textual prompt using stable diffusion"
|
||||
)
|
||||
parser.add_argument("--quantize", "-q", action="store_true")
|
||||
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
|
||||
parser.add_argument("--output", default="out.png")
|
||||
args = parser.parse_args()
|
||||
|
||||
flux = FluxPipeline("flux-" + args.model, t5_padding=True)
|
||||
|
||||
if args.quantize:
|
||||
nn.quantize(flux.flow, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.t5, class_predicate=quantization_predicate)
|
||||
nn.quantize(flux.clip, class_predicate=quantization_predicate)
|
||||
|
||||
group = mx.distributed.init()
|
||||
if group.size() > 1:
|
||||
flux.flow.shard(group)
|
||||
|
||||
print_zero(group, "Loading models")
|
||||
flux.ensure_models_are_loaded()
|
||||
|
||||
def print_help():
|
||||
print_zero(group, "The command list:")
|
||||
print_zero(group, "- 'q' to exit")
|
||||
print_zero(group, "- 's HxW' to change the size of the image")
|
||||
print_zero(group, "- 'n S' to change the number of steps")
|
||||
print_zero(group, "- 'h' to print this help")
|
||||
|
||||
print_zero(group, "FLUX interactive session")
|
||||
print_help()
|
||||
seed = 0
|
||||
size = (512, 512)
|
||||
latent_size = to_latent_size(size)
|
||||
steps = 50 if args.model == "dev" else 4
|
||||
while True:
|
||||
prompt = input(">> " if group.rank() == 0 else "")
|
||||
if prompt == "q":
|
||||
break
|
||||
if prompt == "h":
|
||||
print_help()
|
||||
continue
|
||||
if prompt.startswith("s "):
|
||||
size = tuple([int(xi) for xi in prompt[2:].split("x")])
|
||||
print_zero(group, "Setting the size to", size)
|
||||
latent_size = to_latent_size(size)
|
||||
continue
|
||||
if prompt.startswith("n "):
|
||||
steps = int(prompt[2:])
|
||||
print_zero(group, "Setting the steps to", steps)
|
||||
continue
|
||||
|
||||
seed += 1
|
||||
latents = flux.generate_latents(
|
||||
prompt,
|
||||
n_images=1,
|
||||
num_steps=steps,
|
||||
latent_size=latent_size,
|
||||
guidance=4.0,
|
||||
seed=seed,
|
||||
)
|
||||
print_zero(group, "Processing prompt")
|
||||
mx.eval(next(latents))
|
||||
print_zero(group, "Generating latents")
|
||||
for xt in tqdm(latents, total=steps, disable=group.rank() > 0):
|
||||
mx.eval(xt)
|
||||
print_zero(group, "Generating image")
|
||||
xt = flux.decode(xt, latent_size)
|
||||
xt = (xt * 255).astype(mx.uint8)
|
||||
mx.eval(xt)
|
||||
im = Image.fromarray(np.array(xt[0]))
|
||||
im.save(args.output)
|
||||
print_zero(group, "Saved at", args.output, end="\n\n")
|
Loading…
Reference in New Issue
Block a user