Add FLUX finetuning (#1028)

This commit is contained in:
Angelos Katharopoulos 2024-10-11 21:17:41 -07:00 committed by GitHub
parent d72fdeb4ee
commit a5f2bab070
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
18 changed files with 2685 additions and 0 deletions

185
flux/README.md Normal file
View File

@ -0,0 +1,185 @@
FLUX
====
FLUX implementation in MLX. The implementation is ported directly from
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
and the model weights are downloaded directly from the Hugging Face Hub.
The goal of this example is to be clean, educational and to allow for
experimentation with finetuning FLUX models as well as adding extra
functionality such as in-/outpainting, guidance with custom losses etc.
![MLX image](static/generated-mlx.png)
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
tron emanating futuristic technology with the word "MLX" in the center with
capital red letters.'*
Installation
------------
The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
- `sentencepiece` for the T5 tokenizer
You can install all of the above with the `requirements.txt` as follows:
pip install -r requirements.txt
Inference
---------
Inference in this example is similar to the stable diffusion example. The
classes to get you started are `FluxPipeline` from the `flux` module.
```python
import mlx.core as mx
from flux import FluxPipeline
# This will download all the weights from HF hub
flux = FluxPipeline("flux-schnell")
# Make a generator that returns the latent variables from the reverse diffusion
# process
latent_generator = flux.generate_latents(
"A photo of an astronaut riding a horse on Mars",
num_steps=4,
latent_size=(32, 64), # 256x512 image
)
# The first return value of the generator contains the conditioning and the
# random noise at the beginning of the diffusion process.
conditioning = next(latent_generator)
(
x_T, # The initial noise
x_positions, # The integer positions used for image positional encoding
t5_conditioning, # The T5 features from the text prompt
t5_positions, # Integer positions for text (normally all 0s)
clip_conditioning, # The clip text features from the text prompt
) = conditioning
# Returning the conditioning as the first output from the generator allows us
# to unload T5 and clip before running the diffusion transformer.
mx.eval(conditioning)
# Evaluate each diffusion step
for x_t in latent_generator:
mx.eval(x_t)
# Note that we need to pass the latent size because it is collapsed and
# patchified in x_t and we need to unwrap it.
img = flux.decode(x_t, latent_size=(32, 64))
```
The above are essentially the implementation of the `txt2image.py` script
except for some additional logic to quantize and/or load trained adapters. One
can use the script as follows:
```shell
python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.'
```
### Experimental Options
FLUX pads the prompt to a specific size of 512 tokens for the dev model and
256 for the schnell model. Not applying padding results in faster generation
but it is not clear how it may affect the generated images. To enable that
option in this example pass `--no-t5-padding` to the `txt2image.py` script or
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
Finetuning
----------
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
but ymmv) on a provided image dataset. The dataset folder must have an
`index.json` file with the following format:
```json
{
"data": [
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
...
]
}
```
The training script by default trains for 600 iterations with a batch size of
1, gradient accumulation of 4 and LoRA rank of 8. Run `python dreambooth.py
--help` for the list of hyperparameters you can tune.
> [!Note]
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
> should reduce this number significantly.
### Training Example
This is a step-by-step finetuning example. We will be using the data from
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
In particular, we will use `dog6` which is a popular example for showcasing
dreambooth [^1].
The training images are the following 5 images [^2]:
![dog6](static/dog6.png)
We start by making the following `index.json` file and placing it in the same
folder as the images.
```json
{
"data": [
{"image": "00.jpg", "text": "A photo of sks dog"},
{"image": "01.jpg", "text": "A photo of sks dog"},
{"image": "02.jpg", "text": "A photo of sks dog"},
{"image": "03.jpg", "text": "A photo of sks dog"},
{"image": "04.jpg", "text": "A photo of sks dog"}
]
}
```
Subsequently we finetune FLUX using the following command:
```shell
python dreambooth.py \
--progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
--progress-every 600 --iterations 1200 --learning-rate 0.0001 \
--lora-rank 4 --grad-accumulate 8 \
path/to/dreambooth/dataset/dog6
```
The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour.
### Using the Adapter
The adapters are saved in `mlx_output` and can be used directly by the
`txt2image.py` script. For instance,
```shell
python txt2img.py --model dev --save-raw --image-size 512x512 --n-images 1 \
--adapter mlx_output/mlx_output/0001200_adapters.safetensors \
--fuse-adapter \
--no-t5-padding \
'A photo of an sks dog lying on the sand at a beach in Greece'
```
generates an image that looks like the following,
![dog image](static/dog-r4-g8-1200.png)
and of course we can pass `--image-size 512x1024` to get larger images with
different aspect ratios,
![wide dog image](static/dog-r4-g8-1200-512x1024.png)
The arguments that are relevant to the adapters are of course `--adapter` and
`--fuse-adapter`. The first defines the path to an adapter to apply to the
model and the second fuses the adapter back into the model to get a bit more
speed during generation.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from unsplash by https://unsplash.com/@alvannee .

378
flux/dreambooth.py Normal file
View File

@ -0,0 +1,378 @@
# Copyright © 2024 Apple Inc.
import argparse
import json
import time
from functools import partial
from pathlib import Path
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
import numpy as np
from mlx.nn.utils import average_gradients
from mlx.utils import tree_flatten, tree_map, tree_reduce
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
class FinetuningDataset:
def __init__(self, flux, args):
self.args = args
self.flux = flux
self.dataset_base = Path(args.dataset)
dataset_index = self.dataset_base / "index.json"
if not dataset_index.exists():
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
with open(dataset_index, "r") as f:
self.index = json.load(f)
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.index["data"]):
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]
def generate_progress_images(iteration, flux, args):
"""Generate images to monitor the progress of the finetuning."""
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_progress.png"
print(f"Generating {str(out_file)}", flush=True)
# Generate some images and arrange them in a grid
n_rows = 2
n_images = 4
x = flux.generate_images(
args.progress_prompt,
n_images,
args.progress_steps,
)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(n_rows * H, B // n_rows * W, C)
x = mx.pad(x, [(4, 4), (4, 4), (0, 0)])
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(out_file)
def save_adapters(iteration, flux, args):
out_dir = Path(args.output_dir)
out_dir.mkdir(parents=True, exist_ok=True)
out_file = out_dir / f"{iteration:07d}_adapters.safetensors"
print(f"Saving {str(out_file)}")
mx.save_safetensors(
str(out_file),
dict(tree_flatten(flux.flow.trainable_parameters())),
metadata={
"lora_rank": str(args.lora_rank),
"lora_blocks": str(args.lora_blocks),
},
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Finetune Flux to generate images with a specific subject"
)
parser.add_argument(
"--model",
default="dev",
choices=[
"dev",
"schnell",
],
help="Which flux model to train",
)
parser.add_argument(
"--guidance", type=float, default=4.0, help="The guidance factor to use."
)
parser.add_argument(
"--iterations",
type=int,
default=600,
help="How many iterations to train for",
)
parser.add_argument(
"--batch-size",
type=int,
default=1,
help="The batch size to use when training the stable diffusion model",
)
parser.add_argument(
"--resolution",
type=lambda x: tuple(map(int, x.split("x"))),
default=(512, 512),
help="The resolution of the training images",
)
parser.add_argument(
"--num-augmentations",
type=int,
default=5,
help="Augment the images by random cropping and panning",
)
parser.add_argument(
"--progress-prompt",
required=True,
help="Use this prompt when generating images for evaluation",
)
parser.add_argument(
"--progress-steps",
type=int,
default=50,
help="Use this many steps when generating images for evaluation",
)
parser.add_argument(
"--progress-every",
type=int,
default=50,
help="Generate images every PROGRESS_EVERY steps",
)
parser.add_argument(
"--checkpoint-every",
type=int,
default=50,
help="Save the model every CHECKPOINT_EVERY steps",
)
parser.add_argument(
"--lora-blocks",
type=int,
default=-1,
help="Train the last LORA_BLOCKS transformer blocks",
)
parser.add_argument(
"--lora-rank", type=int, default=8, help="LoRA rank for finetuning"
)
parser.add_argument(
"--warmup-steps", type=int, default=100, help="Learning rate warmup"
)
parser.add_argument(
"--learning-rate", type=float, default="1e-4", help="Learning rate for training"
)
parser.add_argument(
"--grad-accumulate",
type=int,
default=4,
help="Accumulate gradients for that many iterations before applying them",
)
parser.add_argument(
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
)
parser.add_argument("dataset")
args = parser.parse_args()
# Load the model and set it up for LoRA training. We use the same random
# state when creating the LoRA layers so all workers will have the same
# initial weights.
mx.random.seed(0x0F0F0F0F)
flux = FluxPipeline("flux-" + args.model)
flux.flow.freeze()
flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks)
# Reset the seed to a different seed per worker if we are in distributed
# mode so that each worker is working on different data, diffusion step and
# random noise.
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
# Report how many parameters we are training
trainable_params = tree_reduce(
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
)
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
# Set up the optimizer and training steps. The steps are a bit verbose to
# support gradient accumulation together with compilation.
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
cosine = optim.cosine_decay(
args.learning_rate, args.iterations // args.grad_accumulate
)
lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps])
optimizer = optim.Adam(learning_rate=lr_schedule)
state = [flux.flow.state, optimizer.state, mx.random.state]
@partial(mx.compile, inputs=state, outputs=state)
def single_step(x, t5_feat, clip_feat, guidance):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)
return loss
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_grads(x, t5_feat, clip_feat, guidance):
return nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
@partial(mx.compile, inputs=state, outputs=state)
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
return loss, grads
@partial(mx.compile, inputs=state, outputs=state)
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
x, t5_feat, clip_feat, guidance
)
grads = tree_map(
lambda a, b: (a + b) / args.grad_accumulate,
prev_grads,
grads,
)
grads = average_gradients(grads)
optimizer.update(flux.flow, grads)
return loss
# We simply route to the appropriate step based on whether we have
# gradients from a previous step and whether we should be performing an
# update or simply computing and accumulating gradients in this step.
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
if prev_grads is None:
if perform_step:
return single_step(x, t5_feat, clip_feat, guidance), None
else:
return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
else:
if perform_step:
return (
grad_accumulate_and_step(
x, t5_feat, clip_feat, guidance, prev_grads
),
None,
)
else:
return compute_loss_and_accumulate_grads(
x, t5_feat, clip_feat, guidance, prev_grads
)
print("Create the training dataset.", flush=True)
dataset = FinetuningDataset(flux, args)
dataset.encode_images()
dataset.encode_prompts()
guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype)
# An initial generation to compare
generate_progress_images(0, flux, args)
grads = None
losses = []
tic = time.time()
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
mx.eval(loss, grads, state)
losses.append(loss.item())
if (i + 1) % 10 == 0:
toc = time.time()
peak_mem = mx.metal.get_peak_memory() / 1024**3
print(
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
f"It/s: {10 / (toc - tic):.3f} "
f"Peak mem: {peak_mem:.3f} GB",
flush=True,
)
if (i + 1) % args.progress_every == 0:
generate_progress_images(i + 1, flux, args)
if (i + 1) % args.checkpoint_every == 0:
save_adapters(i + 1, flux, args)
if (i + 1) % 10 == 0:
losses = []
tic = time.time()

248
flux/flux/__init__.py Normal file
View File

@ -0,0 +1,248 @@
# Copyright © 2024 Apple Inc.
import math
import time
from typing import Tuple
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
from tqdm import tqdm
from .lora import LoRALinear
from .sampler import FluxSampler
from .utils import (
load_ae,
load_clip,
load_clip_tokenizer,
load_flow_model,
load_t5,
load_t5_tokenizer,
)
class FluxPipeline:
def __init__(self, name: str, t5_padding: bool = True):
self.dtype = mx.bfloat16
self.name = name
self.t5_padding = t5_padding
self.ae = load_ae(name)
self.flow = load_flow_model(name)
self.clip = load_clip(name)
self.clip_tokenizer = load_clip_tokenizer(name)
self.t5 = load_t5(name)
self.t5_tokenizer = load_t5_tokenizer(name)
self.sampler = FluxSampler(name)
def ensure_models_are_loaded(self):
mx.eval(
self.ae.parameters(),
self.flow.parameters(),
self.clip.parameters(),
self.t5.parameters(),
)
def reload_text_encoders(self):
self.t5 = load_t5(self.name)
self.clip = load_clip(self.name)
def tokenize(self, text):
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
clip_tokens = self.clip_tokenizer.encode(text)
return t5_tokens, clip_tokens
def _prepare_latent_images(self, x):
b, h, w, c = x.shape
# Pack the latent image to 2x2 patches
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
# Create positions ids used to positionally encode each patch. Due to
# the way RoPE works, this results in an interesting positional
# encoding where parts of the feature are holding different positional
# information. Namely, the first part holds information independent of
# the spatial position (hence 0s), the 2nd part holds vertical spatial
# information and the last one horizontal.
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
x_ids = mx.stack([i, j, k], axis=-1)
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
return x, x_ids
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
# Prepare the text features
txt = self.t5(t5_tokens)
if len(txt) == 1 and n_images > 1:
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
# Prepare the clip text features
vec = self.clip(clip_tokens).pooled_output
if len(vec) == 1 and n_images > 1:
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
return txt, txt_ids, vec
def _denoising_loop(
self,
x_t,
x_ids,
txt,
txt_ids,
vec,
num_steps: int = 35,
guidance: float = 4.0,
start: float = 1,
stop: float = 0,
):
B = len(x_t)
def scalar(x):
return mx.full((B,), x, dtype=self.dtype)
guidance = scalar(guidance)
timesteps = self.sampler.timesteps(
num_steps,
x_t.shape[1],
start=start,
stop=stop,
)
for i in range(num_steps):
t = timesteps[i]
t_prev = timesteps[i + 1]
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=scalar(t),
guidance=guidance,
)
x_t = self.sampler.step(pred, x_t, t, t_prev)
yield x_t
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
):
# Set the PRNG state
if seed is not None:
mx.random.seed(seed)
# Create the latent variables
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
x_T, x_ids = self._prepare_latent_images(x_T)
# Get the conditioning
t5_tokens, clip_tokens = self.tokenize(text)
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
# Yield the conditioning for controlled evaluation by the caller
yield (x_T, x_ids, txt, txt_ids, vec)
# Yield the latent sequences from the denoising loop
yield from self._denoising_loop(
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
)
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
h, w = latent_size
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
x = self.ae.decode(x)
return mx.clip(x + 1, 0, 2) * 0.5
def generate_images(
self,
text: str,
n_images: int = 1,
num_steps: int = 35,
guidance: float = 4.0,
latent_size: Tuple[int, int] = (64, 64),
seed=None,
reload_text_encoders: bool = True,
progress: bool = True,
):
latents = self.generate_latents(
text, n_images, num_steps, guidance, latent_size, seed
)
mx.eval(next(latents))
if reload_text_encoders:
self.reload_text_encoders()
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
mx.eval(x_t)
images = []
for i in tqdm(range(len(x_t)), disable=not progress):
images.append(self.decode(x_t[i : i + 1]))
mx.eval(images[-1])
images = mx.concatenate(images, axis=0)
mx.eval(images)
return images
def training_loss(
self,
x_0: mx.array,
t5_features: mx.array,
clip_features: mx.array,
guidance: mx.array,
):
# Get the text conditioning
txt = t5_features
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
vec = clip_features
# Prepare the latent input
x_0, x_ids = self._prepare_latent_images(x_0)
# Forward process
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
x_t = self.sampler.add_noise(x_0, t, noise=eps)
x_t = mx.stop_gradient(x_t)
# Do the denoising
pred = self.flow(
img=x_t,
img_ids=x_ids,
txt=txt,
txt_ids=txt_ids,
y=vec,
timesteps=t,
guidance=guidance,
)
return (pred + x_0 - eps).square().mean()
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
"""Swap the linear layers in the transformer blocks with LoRA layers."""
all_blocks = self.flow.double_blocks + self.flow.single_blocks
all_blocks.reverse()
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
for i, block in zip(range(num_blocks), all_blocks):
loras = []
for name, module in block.named_modules():
if isinstance(module, nn.Linear):
loras.append((name, LoRALinear.from_base(module, r=rank)))
block.update_modules(tree_unflatten(loras))
def fuse_lora_layers(self):
fused_layers = []
for name, module in self.flow.named_modules():
if isinstance(module, LoRALinear):
fused_layers.append((name, module.fuse()))
self.flow.update_modules(tree_unflatten(fused_layers))

357
flux/flux/autoencoder.py Normal file
View File

@ -0,0 +1,357 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import List
import mlx.core as mx
import mlx.nn as nn
from mlx.nn.layers.upsample import upsample_nearest
@dataclass
class AutoEncoderParams:
resolution: int
in_channels: int
ch: int
out_ch: int
ch_mult: List[int]
num_res_blocks: int
z_channels: int
scale_factor: float
shift_factor: float
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.in_channels = in_channels
self.norm = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.q = nn.Linear(in_channels, in_channels)
self.k = nn.Linear(in_channels, in_channels)
self.v = nn.Linear(in_channels, in_channels)
self.proj_out = nn.Linear(in_channels, in_channels)
def __call__(self, x: mx.array) -> mx.array:
B, H, W, C = x.shape
y = x.reshape(B, 1, -1, C)
y = self.norm(y)
q = self.q(y)
k = self.k(y)
v = self.v(y)
y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5))
y = self.proj_out(y)
return x + y.reshape(B, H, W, C)
class ResnetBlock(nn.Module):
def __init__(self, in_channels: int, out_channels: int):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.norm1 = nn.GroupNorm(
num_groups=32,
dims=in_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
self.norm2 = nn.GroupNorm(
num_groups=32,
dims=out_channels,
eps=1e-6,
affine=True,
pytorch_compatible=True,
)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if self.in_channels != self.out_channels:
self.nin_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x):
h = x
h = self.norm1(h)
h = nn.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = nn.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
x = self.nin_shortcut(x)
return x + h
class Downsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=2, padding=0
)
def __call__(self, x: mx.array):
x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)])
x = self.conv(x)
return x
class Upsample(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
self.conv = nn.Conv2d(
in_channels, in_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
x = upsample_nearest(x, (2, 2))
x = self.conv(x)
return x
class Encoder(nn.Module):
def __init__(
self,
resolution: int,
in_channels: int,
ch: int,
ch_mult: list[int],
num_res_blocks: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
# downsampling
self.conv_in = nn.Conv2d(
in_channels, self.ch, kernel_size=3, stride=1, padding=1
)
curr_res = resolution
in_ch_mult = (1,) + tuple(ch_mult)
self.in_ch_mult = in_ch_mult
self.down = []
block_in = self.ch
for i_level in range(self.num_resolutions):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_in = ch * in_ch_mult[i_level]
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
down = {}
down["block"] = block
down["attn"] = attn
if i_level != self.num_resolutions - 1:
down["downsample"] = Downsample(block_in)
curr_res = curr_res // 2
self.down.append(down)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(
block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x: mx.array):
hs = [self.conv_in(x)]
for i_level in range(self.num_resolutions):
for i_block in range(self.num_res_blocks):
h = self.down[i_level]["block"][i_block](hs[-1])
# TODO: Remove the attn
if len(self.down[i_level]["attn"]) > 0:
h = self.down[i_level]["attn"][i_block](h)
hs.append(h)
if i_level != self.num_resolutions - 1:
hs.append(self.down[i_level]["downsample"](hs[-1]))
# middle
h = hs[-1]
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class Decoder(nn.Module):
def __init__(
self,
ch: int,
out_ch: int,
ch_mult: list[int],
num_res_blocks: int,
in_channels: int,
resolution: int,
z_channels: int,
):
super().__init__()
self.ch = ch
self.num_resolutions = len(ch_mult)
self.num_res_blocks = num_res_blocks
self.resolution = resolution
self.in_channels = in_channels
self.ffactor = 2 ** (self.num_resolutions - 1)
# compute in_ch_mult, block_in and curr_res at lowest res
block_in = ch * ch_mult[self.num_resolutions - 1]
curr_res = resolution // 2 ** (self.num_resolutions - 1)
self.z_shape = (1, z_channels, curr_res, curr_res)
# z to block_in
self.conv_in = nn.Conv2d(
z_channels, block_in, kernel_size=3, stride=1, padding=1
)
# middle
self.mid = {}
self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
self.mid["attn_1"] = AttnBlock(block_in)
self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in)
# upsampling
self.up = []
for i_level in reversed(range(self.num_resolutions)):
block = []
attn = [] # TODO: Remove the attn, nobody appends anything to it
block_out = ch * ch_mult[i_level]
for _ in range(self.num_res_blocks + 1):
block.append(ResnetBlock(in_channels=block_in, out_channels=block_out))
block_in = block_out
up = {}
up["block"] = block
up["attn"] = attn
if i_level != 0:
up["upsample"] = Upsample(block_in)
curr_res = curr_res * 2
self.up.insert(0, up) # prepend to get consistent order
# end
self.norm_out = nn.GroupNorm(
num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1)
def __call__(self, z: mx.array):
# z to block_in
h = self.conv_in(z)
# middle
h = self.mid["block_1"](h)
h = self.mid["attn_1"](h)
h = self.mid["block_2"](h)
# upsampling
for i_level in reversed(range(self.num_resolutions)):
for i_block in range(self.num_res_blocks + 1):
h = self.up[i_level]["block"][i_block](h)
# TODO: Remove the attn
if len(self.up[i_level]["attn"]) > 0:
h = self.up[i_level]["attn"][i_block](h)
if i_level != 0:
h = self.up[i_level]["upsample"](h)
# end
h = self.norm_out(h)
h = nn.silu(h)
h = self.conv_out(h)
return h
class DiagonalGaussian(nn.Module):
def __call__(self, z: mx.array):
mean, logvar = mx.split(z, 2, axis=-1)
if self.training:
std = mx.exp(0.5 * logvar)
eps = mx.random.normal(shape=z.shape, dtype=z.dtype)
return mean + std * eps
else:
return mean
class AutoEncoder(nn.Module):
def __init__(self, params: AutoEncoderParams):
super().__init__()
self.encoder = Encoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.decoder = Decoder(
resolution=params.resolution,
in_channels=params.in_channels,
ch=params.ch,
out_ch=params.out_ch,
ch_mult=params.ch_mult,
num_res_blocks=params.num_res_blocks,
z_channels=params.z_channels,
)
self.reg = DiagonalGaussian()
self.scale_factor = params.scale_factor
self.shift_factor = params.shift_factor
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if w.ndim == 4:
w = w.transpose(0, 2, 3, 1)
w = w.reshape(-1).reshape(w.shape)
if w.shape[1:3] == (1, 1):
w = w.squeeze((1, 2))
new_weights[k] = w
return new_weights
def encode(self, x: mx.array):
z = self.reg(self.encoder(x))
z = self.scale_factor * (z - self.shift_factor)
return z
def decode(self, z: mx.array):
z = z / self.scale_factor + self.shift_factor
return self.decoder(z)
def __call__(self, x: mx.array):
return self.decode(self.encode(x))

154
flux/flux/clip.py Normal file
View File

@ -0,0 +1,154 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
hidden_act: str = "quick_gelu"
@classmethod
def from_dict(cls, config):
return cls(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
hidden_act=config["hidden_act"],
)
@dataclass
class CLIPOutput:
# The last_hidden_state indexed at the EOS token and possibly projected if
# the model has a projection layer
pooled_output: Optional[mx.array] = None
# The full sequence output of the transformer after the final layernorm
last_hidden_state: Optional[mx.array] = None
# A list of hidden states corresponding to the outputs of the transformer layers
hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
def _get_mask(self, N, dtype):
indices = mx.arange(N)
mask = indices[:, None] < indices[None]
mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
return mask
def sanitize(self, weights):
new_weights = {}
for key, w in weights.items():
# Remove prefixes
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
new_weights[key] = w
return new_weights
def __call__(self, x):
# Extract some shapes
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = self._get_mask(N, x.dtype)
hidden_states = []
for l in self.layers:
x = l(x, mask)
hidden_states.append(x)
# Apply the final layernorm and return
x = self.final_layer_norm(x)
last_hidden_state = x
# Select the EOS token
pooled_output = x[mx.arange(len(x)), eos_tokens]
return CLIPOutput(
pooled_output=pooled_output,
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
)

302
flux/flux/layers.py Normal file
View File

@ -0,0 +1,302 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from functools import partial
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
def _rope(pos: mx.array, dim: int, theta: float):
scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim
omega = 1.0 / (theta**scale)
x = pos[..., None] * omega
cosx = mx.cos(x)
sinx = mx.sin(x)
pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1)
pe = pe.reshape(*pe.shape[:-1], 2, 2)
return pe
@partial(mx.compile, shapeless=True)
def _ab_plus_cd(a, b, c, d):
return a * b + c * d
def _apply_rope(x, pe):
s = x.shape
x = x.reshape(*s[:-1], -1, 1, 2)
x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1])
return x.reshape(s)
def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array):
B, H, L, D = q.shape
q = _apply_rope(q, pe)
k = _apply_rope(k, pe)
x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5))
return x.transpose(0, 2, 1, 3).reshape(B, L, -1)
def timestep_embedding(
t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0
):
half = dim // 2
freqs = mx.arange(0, half, dtype=mx.float32) / half
freqs = freqs * (-math.log(max_period))
freqs = mx.exp(freqs)
x = (time_factor * t)[:, None] * freqs[None]
x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1)
return x.astype(t.dtype)
class EmbedND(nn.Module):
def __init__(self, dim: int, theta: int, axes_dim: List[int]):
super().__init__()
self.dim = dim
self.theta = theta
self.axes_dim = axes_dim
def __call__(self, ids: mx.array):
n_axes = ids.shape[-1]
pe = mx.concatenate(
[_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
axis=-3,
)
return pe[:, None]
class MLPEmbedder(nn.Module):
def __init__(self, in_dim: int, hidden_dim: int):
super().__init__()
self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)
def __call__(self, x: mx.array) -> mx.array:
return self.out_layer(nn.silu(self.in_layer(x)))
class QKNorm(nn.Module):
def __init__(self, dim: int):
super().__init__()
self.query_norm = nn.RMSNorm(dim)
self.key_norm = nn.RMSNorm(dim)
def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]:
return self.query_norm(q), self.key_norm(k)
class SelfAttention(nn.Module):
def __init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
self.norm = QKNorm(head_dim)
self.proj = nn.Linear(dim, dim)
def __call__(self, x: mx.array, pe: mx.array) -> mx.array:
H = self.num_heads
B, L, _ = x.shape
qkv = self.qkv(x)
q, k, v = mx.split(qkv, 3, axis=-1)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
x = _attention(q, k, v, pe)
x = self.proj(x)
return x
@dataclass
class ModulationOut:
shift: mx.array
scale: mx.array
gate: mx.array
class Modulation(nn.Module):
def __init__(self, dim: int, double: bool):
super().__init__()
self.is_double = double
self.multiplier = 6 if double else 3
self.lin = nn.Linear(dim, self.multiplier * dim, bias=True)
def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]:
x = self.lin(nn.silu(x))
xs = mx.split(x[:, None, :], self.multiplier, axis=-1)
mod1 = ModulationOut(*xs[:3])
mod2 = ModulationOut(*xs[3:]) if self.is_double else None
return mod1, mod2
class DoubleStreamBlock(nn.Module):
def __init__(
self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False
):
super().__init__()
mlp_hidden_dim = int(hidden_size * mlp_ratio)
self.num_heads = num_heads
self.hidden_size = hidden_size
self.img_mod = Modulation(hidden_size, double=True)
self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.img_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
self.txt_mod = Modulation(hidden_size, double=True)
self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_attn = SelfAttention(
dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias
)
self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.txt_mlp = nn.Sequential(
nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
nn.GELU(approx="tanh"),
nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
)
def __call__(
self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array
) -> Tuple[mx.array, mx.array]:
B, L, _ = img.shape
_, S, _ = txt.shape
H = self.num_heads
img_mod1, img_mod2 = self.img_mod(vec)
txt_mod1, txt_mod2 = self.txt_mod(vec)
# prepare image for attention
img_modulated = self.img_norm1(img)
img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift
img_qkv = self.img_attn.qkv(img_modulated)
img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1)
img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
img_q, img_k = self.img_attn.norm(img_q, img_k)
# prepare txt for attention
txt_modulated = self.txt_norm1(txt)
txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift
txt_qkv = self.txt_attn.qkv(txt_modulated)
txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1)
txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3)
txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k)
# run actual attention
q = mx.concatenate([txt_q, img_q], axis=2)
k = mx.concatenate([txt_k, img_k], axis=2)
v = mx.concatenate([txt_v, img_v], axis=2)
attn = _attention(q, k, v, pe)
txt_attn, img_attn = mx.split(attn, [S], axis=1)
# calculate the img bloks
img = img + img_mod1.gate * self.img_attn.proj(img_attn)
img = img + img_mod2.gate * self.img_mlp(
(1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift
)
# calculate the txt bloks
txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn)
txt = txt + txt_mod2.gate * self.txt_mlp(
(1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift
)
return img, txt
class SingleStreamBlock(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
mlp_ratio: float = 4.0,
qk_scale: Optional[float] = None,
):
super().__init__()
self.hidden_dim = hidden_size
self.num_heads = num_heads
head_dim = hidden_size // num_heads
self.scale = qk_scale or head_dim**-0.5
self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
# qkv and mlp_in
self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim)
# proj and mlp_out
self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size)
self.norm = QKNorm(head_dim)
self.hidden_size = hidden_size
self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.mlp_act = nn.GELU(approx="tanh")
self.modulation = Modulation(hidden_size, double=False)
def __call__(self, x: mx.array, vec: mx.array, pe: mx.array):
B, L, _ = x.shape
H = self.num_heads
mod, _ = self.modulation(vec)
x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift
q, k, v, mlp = mx.split(
self.linear1(x_mod),
[self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size],
axis=-1,
)
q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3)
q, k = self.norm(q, k)
# compute attention
y = _attention(q, k, v, pe)
# compute activation in mlp stream, cat again and run second linear layer
y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2))
return x + mod.gate * y
class LastLayer(nn.Module):
def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
super().__init__()
self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6)
self.linear = nn.Linear(
hidden_size, patch_size * patch_size * out_channels, bias=True
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True)
)
def __call__(self, x: mx.array, vec: mx.array):
shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1)
x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :]
x = self.linear(x)
return x

76
flux/flux/lora.py Normal file
View File

@ -0,0 +1,76 @@
# Copyright © 2024 Apple Inc.
import math
import mlx.core as mx
import mlx.nn as nn
class LoRALinear(nn.Module):
@staticmethod
def from_base(
linear: nn.Linear,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
):
output_dims, input_dims = linear.weight.shape
lora_lin = LoRALinear(
input_dims=input_dims,
output_dims=output_dims,
r=r,
dropout=dropout,
scale=scale,
)
lora_lin.linear = linear
return lora_lin
def fuse(self):
linear = self.linear
bias = "bias" in linear
weight = linear.weight
dtype = weight.dtype
output_dims, input_dims = weight.shape
fused_linear = nn.Linear(input_dims, output_dims, bias=bias)
lora_b = self.scale * self.lora_b.T
lora_a = self.lora_a.T
fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype)
if bias:
fused_linear.bias = linear.bias
return fused_linear
def __init__(
self,
input_dims: int,
output_dims: int,
r: int = 8,
dropout: float = 0.0,
scale: float = 1.0,
bias: bool = False,
):
super().__init__()
# Regular linear layer weights
self.linear = nn.Linear(input_dims, output_dims, bias=bias)
self.dropout = nn.Dropout(p=dropout)
# Scale for low-rank update
self.scale = scale
# Low rank lora weights
scale = 1 / math.sqrt(input_dims)
self.lora_a = mx.random.uniform(
low=-scale,
high=scale,
shape=(input_dims, r),
)
self.lora_b = mx.zeros(shape=(r, output_dims))
def __call__(self, x):
y = self.linear(x)
z = (self.dropout(x) @ self.lora_a) @ self.lora_b
return y + (self.scale * z).astype(x.dtype)

134
flux/flux/model.py Normal file
View File

@ -0,0 +1,134 @@
# Copyright © 2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .layers import (
DoubleStreamBlock,
EmbedND,
LastLayer,
MLPEmbedder,
SingleStreamBlock,
timestep_embedding,
)
@dataclass
class FluxParams:
in_channels: int
vec_in_dim: int
context_in_dim: int
hidden_size: int
mlp_ratio: float
num_heads: int
depth: int
depth_single_blocks: int
axes_dim: list[int]
theta: int
qkv_bias: bool
guidance_embed: bool
class Flux(nn.Module):
def __init__(self, params: FluxParams):
super().__init__()
self.params = params
self.in_channels = params.in_channels
self.out_channels = self.in_channels
if params.hidden_size % params.num_heads != 0:
raise ValueError(
f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}"
)
pe_dim = params.hidden_size // params.num_heads
if sum(params.axes_dim) != pe_dim:
raise ValueError(
f"Got {params.axes_dim} but expected positional dim {pe_dim}"
)
self.hidden_size = params.hidden_size
self.num_heads = params.num_heads
self.pe_embedder = EmbedND(
dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim
)
self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True)
self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size)
self.guidance_in = (
MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size)
if params.guidance_embed
else nn.Identity()
)
self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size)
self.double_blocks = [
DoubleStreamBlock(
self.hidden_size,
self.num_heads,
mlp_ratio=params.mlp_ratio,
qkv_bias=params.qkv_bias,
)
for _ in range(params.depth)
]
self.single_blocks = [
SingleStreamBlock(
self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio
)
for _ in range(params.depth_single_blocks)
]
self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
if k.endswith(".scale"):
k = k[:-6] + ".weight"
for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]:
if f".{seq}." in k:
k = k.replace(f".{seq}.", f".{seq}.layers.")
break
new_weights[k] = w
return new_weights
def __call__(
self,
img: mx.array,
img_ids: mx.array,
txt: mx.array,
txt_ids: mx.array,
timesteps: mx.array,
y: mx.array,
guidance: Optional[mx.array] = None,
) -> mx.array:
if img.ndim != 3 or txt.ndim != 3:
raise ValueError("Input img and txt tensors must have 3 dimensions.")
img = self.img_in(img)
vec = self.time_in(timestep_embedding(timesteps, 256))
if self.params.guidance_embed:
if guidance is None:
raise ValueError(
"Didn't get guidance strength for guidance distilled model."
)
vec = vec + self.guidance_in(timestep_embedding(guidance, 256))
vec = vec + self.vector_in(y)
txt = self.txt_in(txt)
ids = mx.concatenate([txt_ids, img_ids], axis=1)
pe = self.pe_embedder(ids).astype(img.dtype)
for block in self.double_blocks:
img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
img = mx.concatenate([txt, img], axis=1)
for block in self.single_blocks:
img = block(img, vec=vec, pe=pe)
img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec)
return img

56
flux/flux/sampler.py Normal file
View File

@ -0,0 +1,56 @@
# Copyright © 2024 Apple Inc.
import math
from functools import lru_cache
import mlx.core as mx
class FluxSampler:
def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5):
self._base_shift = base_shift
self._max_shift = max_shift
self._schnell = "schnell" in name
def _time_shift(self, x, t):
x1, x2 = 256, 4096
t1, t2 = self._base_shift, self._max_shift
exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1)
t = exp_mu / (exp_mu + (1 / t - 1))
return t
@lru_cache
def timesteps(
self, num_steps, image_sequence_length, start: float = 1, stop: float = 0
):
t = mx.linspace(start, stop, num_steps + 1)
if self._schnell:
t = self._time_shift(image_sequence_length, t)
return t.tolist()
def random_timesteps(self, B, L, dtype=mx.float32, key=None):
if self._schnell:
# TODO: Should we upweigh 1 and 0.75?
t = mx.random.randint(1, 5, shape=(B,), key=key)
t = t.astype(dtype) / 4
else:
t = mx.random.uniform(shape=(B,), dtype=dtype, key=key)
t = self._time_shift(L, t)
return t
def sample_prior(self, shape, dtype=mx.float32, key=None):
return mx.random.normal(shape, dtype=dtype, key=key)
def add_noise(self, x, t, noise=None, key=None):
noise = (
noise
if noise is not None
else mx.random.normal(x.shape, dtype=x.dtype, key=key)
)
return x * (1 - t) + t * noise
def step(self, pred, x_t, t, t_prev):
return x_t + (t_prev - t) * pred

244
flux/flux/t5.py Normal file
View File

@ -0,0 +1,244 @@
# Copyright © 2024 Apple Inc.
import math
from dataclasses import dataclass
from typing import List, Optional, Tuple
import mlx.core as mx
import mlx.nn as nn
_SHARED_REPLACEMENT_PATTERNS = [
(".block.", ".layers."),
(".k.", ".key_proj."),
(".o.", ".out_proj."),
(".q.", ".query_proj."),
(".v.", ".value_proj."),
("shared.", "wte."),
("lm_head.", "lm_head.linear."),
(".layer.0.layer_norm.", ".ln1."),
(".layer.1.layer_norm.", ".ln2."),
(".layer.2.layer_norm.", ".ln3."),
(".final_layer_norm.", ".ln."),
(
"layers.0.layer.0.SelfAttention.relative_attention_bias.",
"relative_attention_bias.embeddings.",
),
]
_ENCODER_REPLACEMENT_PATTERNS = [
(".layer.0.SelfAttention.", ".attention."),
(".layer.1.DenseReluDense.", ".dense."),
]
@dataclass
class T5Config:
vocab_size: int
num_layers: int
num_heads: int
relative_attention_num_buckets: int
d_kv: int
d_model: int
feed_forward_proj: str
tie_word_embeddings: bool
d_ff: Optional[int] = None
num_decoder_layers: Optional[int] = None
relative_attention_max_distance: int = 128
layer_norm_epsilon: float = 1e-6
@classmethod
def from_dict(cls, config):
return cls(
vocab_size=config["vocab_size"],
num_layers=config["num_layers"],
num_heads=config["num_heads"],
relative_attention_num_buckets=config["relative_attention_num_buckets"],
d_kv=config["d_kv"],
d_model=config["d_model"],
feed_forward_proj=config["feed_forward_proj"],
tie_word_embeddings=config["tie_word_embeddings"],
d_ff=config.get("d_ff", 4 * config["d_model"]),
num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]),
relative_attention_max_distance=config.get(
"relative_attention_max_distance", 128
),
layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6),
)
class RelativePositionBias(nn.Module):
def __init__(self, config: T5Config, bidirectional: bool):
self.bidirectional = bidirectional
self.num_buckets = config.relative_attention_num_buckets
self.max_distance = config.relative_attention_max_distance
self.n_heads = config.num_heads
self.embeddings = nn.Embedding(self.num_buckets, self.n_heads)
@staticmethod
def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance):
num_buckets = num_buckets // 2 if bidirectional else num_buckets
max_exact = num_buckets // 2
abspos = rpos.abs()
is_small = abspos < max_exact
scale = (num_buckets - max_exact) / math.log(max_distance / max_exact)
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
buckets = mx.where(is_small, abspos, buckets_large)
if bidirectional:
buckets = buckets + (rpos > 0) * num_buckets
else:
buckets = buckets * (rpos < 0)
return buckets
def __call__(self, query_length: int, key_length: int, offset: int = 0):
"""Compute binned relative position bias"""
context_position = mx.arange(offset, query_length)[:, None]
memory_position = mx.arange(key_length)[None, :]
# shape (query_length, key_length)
relative_position = memory_position - context_position
relative_position_bucket = self._relative_position_bucket(
relative_position,
bidirectional=self.bidirectional,
num_buckets=self.num_buckets,
max_distance=self.max_distance,
)
# shape (query_length, key_length, num_heads)
values = self.embeddings(relative_position_bucket)
# shape (num_heads, query_length, key_length)
return values.transpose(2, 0, 1)
class MultiHeadAttention(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
inner_dim = config.d_kv * config.num_heads
self.num_heads = config.num_heads
self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
def __call__(
self,
queries: mx.array,
keys: mx.array,
values: mx.array,
mask: Optional[mx.array],
cache: Optional[Tuple[mx.array, mx.array]] = None,
) -> [mx.array, Tuple[mx.array, mx.array]]:
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
num_heads = self.num_heads
B, L, _ = queries.shape
_, S, _ = keys.shape
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
if cache is not None:
key_cache, value_cache = cache
keys = mx.concatenate([key_cache, keys], axis=3)
values = mx.concatenate([value_cache, values], axis=2)
values_hat = mx.fast.scaled_dot_product_attention(
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
)
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
return self.out_proj(values_hat), (keys, values)
class DenseActivation(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
mlp_dims = config.d_ff or config.d_model * 4
self.gated = config.feed_forward_proj.startswith("gated")
if self.gated:
self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
else:
self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
activation = config.feed_forward_proj.removeprefix("gated-")
if activation == "relu":
self.act = nn.relu
elif activation == "gelu":
self.act = nn.gelu
elif activation == "silu":
self.act = nn.silu
else:
raise ValueError(f"Unknown activation: {activation}")
def __call__(self, x):
if self.gated:
hidden_act = self.act(self.wi_0(x))
hidden_linear = self.wi_1(x)
x = hidden_act * hidden_linear
else:
x = self.act(self.wi(x))
return self.wo(x)
class TransformerEncoderLayer(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.attention = MultiHeadAttention(config)
self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.dense = DenseActivation(config)
def __call__(self, x, mask):
y = self.ln1(x)
y, _ = self.attention(y, y, y, mask=mask)
x = x + y
y = self.ln2(x)
y = self.dense(y)
return x + y
class TransformerEncoder(nn.Module):
def __init__(self, config: T5Config):
super().__init__()
self.layers = [
TransformerEncoderLayer(config) for i in range(config.num_layers)
]
self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
def __call__(self, x: mx.array):
pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
pos_bias = pos_bias.astype(x.dtype)
for layer in self.layers:
x = layer(x, mask=pos_bias)
return self.ln(x)
class T5Encoder(nn.Module):
def __init__(self, config: T5Config):
self.wte = nn.Embedding(config.vocab_size, config.d_model)
self.encoder = TransformerEncoder(config)
def sanitize(self, weights):
new_weights = {}
for k, w in weights.items():
for old, new in _SHARED_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
if k.startswith("encoder."):
for old, new in _ENCODER_REPLACEMENT_PATTERNS:
k = k.replace(old, new)
new_weights[k] = w
return new_weights
def __call__(self, inputs: mx.array):
return self.encoder(self.wte(inputs))

185
flux/flux/tokenizers.py Normal file
View File

@ -0,0 +1,185 @@
# Copyright © 2024 Apple Inc.
import mlx.core as mx
import regex
from sentencepiece import SentencePieceProcessor
class CLIPTokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab, max_length=77):
self.max_length = max_length
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(
unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
)
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Hugging Face does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
if len(tokens) > self.max_length:
tokens = tokens[: self.max_length]
if append_eos:
tokens[-1] = self.eos_token
return tokens
def encode(self, text):
if not isinstance(text, list):
return self.encode([text])
tokens = self.tokenize(text)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([self.eos_token] * (length - len(t)))
return mx.array(tokens)
class T5Tokenizer:
def __init__(self, model_file, max_length=512):
self._tokenizer = SentencePieceProcessor(model_file)
self.max_length = max_length
@property
def pad(self):
try:
return self._tokenizer.id_to_piece(self.pad_token)
except IndexError:
return None
@property
def pad_token(self):
return self._tokenizer.pad_id()
@property
def bos(self):
try:
return self._tokenizer.id_to_piece(self.bos_token)
except IndexError:
return None
@property
def bos_token(self):
return self._tokenizer.bos_id()
@property
def eos(self):
try:
return self._tokenizer.id_to_piece(self.eos_token)
except IndexError:
return None
@property
def eos_token(self):
return self._tokenizer.eos_id()
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
tokens = self._tokenizer.encode(text)
if prepend_bos and self.bos_token >= 0:
tokens = [self.bos_token] + tokens
if append_eos and self.eos_token >= 0:
tokens.append(self.eos_token)
if pad and len(tokens) < self.max_length and self.pad_token >= 0:
tokens += [self.pad_token] * (self.max_length - len(tokens))
return tokens
def encode(self, text, pad=True):
if not isinstance(text, list):
return self.encode([text], pad=pad)
pad_token = self.pad_token if self.pad_token >= 0 else 0
tokens = self.tokenize(text, pad=pad)
length = max(len(t) for t in tokens)
for t in tokens:
t.extend([pad_token] * (length - len(t)))
return mx.array(tokens)

209
flux/flux/utils.py Normal file
View File

@ -0,0 +1,209 @@
# Copyright © 2024 Apple Inc.
import json
import os
from dataclasses import dataclass
from typing import Optional
import mlx.core as mx
from huggingface_hub import hf_hub_download
from .autoencoder import AutoEncoder, AutoEncoderParams
from .clip import CLIPTextModel, CLIPTextModelConfig
from .model import Flux, FluxParams
from .t5 import T5Config, T5Encoder
from .tokenizers import CLIPTokenizer, T5Tokenizer
@dataclass
class ModelSpec:
params: FluxParams
ae_params: AutoEncoderParams
ckpt_path: Optional[str]
ae_path: Optional[str]
repo_id: Optional[str]
repo_flow: Optional[str]
repo_ae: Optional[str]
configs = {
"flux-dev": ModelSpec(
repo_id="black-forest-labs/FLUX.1-dev",
repo_flow="flux1-dev.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_DEV"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=True,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
"flux-schnell": ModelSpec(
repo_id="black-forest-labs/FLUX.1-schnell",
repo_flow="flux1-schnell.safetensors",
repo_ae="ae.safetensors",
ckpt_path=os.getenv("FLUX_SCHNELL"),
params=FluxParams(
in_channels=64,
vec_in_dim=768,
context_in_dim=4096,
hidden_size=3072,
mlp_ratio=4.0,
num_heads=24,
depth=19,
depth_single_blocks=38,
axes_dim=[16, 56, 56],
theta=10_000,
qkv_bias=True,
guidance_embed=False,
),
ae_path=os.getenv("AE"),
ae_params=AutoEncoderParams(
resolution=256,
in_channels=3,
ch=128,
out_ch=3,
ch_mult=[1, 2, 4, 4],
num_res_blocks=2,
z_channels=16,
scale_factor=0.3611,
shift_factor=0.1159,
),
),
}
def load_flow_model(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ckpt_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_flow is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow)
# Make the model
model = Flux(configs[name].params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = model.sanitize(weights)
model.load_weights(list(weights.items()))
return model
def load_ae(name: str, hf_download: bool = True):
# Get the safetensors file to load
ckpt_path = configs[name].ae_path
# Download if needed
if (
ckpt_path is None
and configs[name].repo_id is not None
and configs[name].repo_ae is not None
and hf_download
):
ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae)
# Make the autoencoder
ae = AutoEncoder(configs[name].ae_params)
# Load the checkpoint if needed
if ckpt_path is not None:
weights = mx.load(ckpt_path)
weights = ae.sanitize(weights)
ae.load_weights(list(weights.items()))
return ae
def load_clip(name: str):
# Load the config
config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json")
with open(config_path) as f:
config = CLIPTextModelConfig.from_dict(json.load(f))
# Make the clip text encoder
clip = CLIPTextModel(config)
# Load the weights
ckpt_path = hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors")
weights = mx.load(ckpt_path)
weights = clip.sanitize(weights)
clip.load_weights(list(weights.items()))
return clip
def load_t5(name: str):
# Load the config
config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json")
with open(config_path) as f:
config = T5Config.from_dict(json.load(f))
# Make the T5 model
t5 = T5Encoder(config)
# Load the weights
model_index = hf_hub_download(
configs[name].repo_id, "text_encoder_2/model.safetensors.index.json"
)
weight_files = set()
with open(model_index) as f:
for _, w in json.load(f)["weight_map"].items():
weight_files.add(w)
weights = {}
for w in weight_files:
w = f"text_encoder_2/{w}"
w = hf_hub_download(configs[name].repo_id, w)
weights.update(mx.load(w))
weights = t5.sanitize(weights)
t5.load_weights(list(weights.items()))
return t5
def load_clip_tokenizer(name: str):
vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json")
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt")
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
def load_t5_tokenizer(name: str, pad: bool = True):
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)

7
flux/requirements.txt Normal file
View File

@ -0,0 +1,7 @@
mlx>=0.18.1
huggingface-hub
regex
numpy
tqdm
Pillow
sentencepiece

Binary file not shown.

After

Width:  |  Height:  |  Size: 754 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 423 KiB

BIN
flux/static/dog6.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 434 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 153 KiB

150
flux/txt2image.py Normal file
View File

@ -0,0 +1,150 @@
# Copyright © 2024 Apple Inc.
import argparse
import mlx.core as mx
import mlx.nn as nn
import numpy as np
from PIL import Image
from tqdm import tqdm
from flux import FluxPipeline
def to_latent_size(image_size):
h, w = image_size
h = ((h + 15) // 16) * 16
w = ((w + 15) // 16) * 16
if (h, w) != image_size:
print(
"Warning: The image dimensions need to be divisible by 16px. "
f"Changing size to {h}x{w}."
)
return (h // 8, w // 8)
def quantization_predicate(name, m):
return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
def load_adapter(flux, adapter_file, fuse=False):
weights, lora_config = mx.load(adapter_file, return_metadata=True)
rank = int(lora_config["lora_rank"])
num_blocks = int(lora_config["lora_blocks"])
flux.linear_to_lora_layers(rank, num_blocks)
flux.flow.load_weights(list(weights.items()), strict=False)
if fuse:
flux.fuse_lora_layers()
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
parser.add_argument("--n-images", type=int, default=4)
parser.add_argument(
"--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
)
parser.add_argument("--steps", type=int)
parser.add_argument("--guidance", type=float, default=4.0)
parser.add_argument("--n-rows", type=int, default=1)
parser.add_argument("--decoding-batch-size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--save-raw", action="store_true")
parser.add_argument("--seed", type=int)
parser.add_argument("--verbose", "-v", action="store_true")
parser.add_argument("--adapter")
parser.add_argument("--fuse-adapter", action="store_true")
parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
args = parser.parse_args()
# Load the models
flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
args.steps = args.steps or (50 if args.model == "dev" else 2)
if args.adapter:
load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
if args.quantize:
nn.quantize(flux.flow, class_predicate=quantization_predicate)
nn.quantize(flux.t5, class_predicate=quantization_predicate)
nn.quantize(flux.clip, class_predicate=quantization_predicate)
if args.preload_models:
sd.ensure_models_are_loaded()
# Make the generator
latent_size = to_latent_size(args.image_size)
latents = flux.generate_latents(
args.prompt,
n_images=args.n_images,
num_steps=args.steps,
latent_size=latent_size,
guidance=args.guidance,
seed=args.seed,
)
# First we get and eval the conditioning
conditioning = next(latents)
mx.eval(conditioning)
peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the text encoders.
del flux.t5
del flux.clip
# Actual denoising loop
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)
# The following is not necessary but it may help in memory constrained
# systems by reusing the memory kept by the flow transformer.
del flux.flow
peak_mem_generation = mx.metal.get_peak_memory() / 1024**3
mx.metal.reset_peak_memory()
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size))
mx.eval(decoded[-1])
peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3
peak_mem_overall = max(
peak_mem_conditioning, peak_mem_generation, peak_mem_decoding
)
if args.save_raw:
*name, suffix = args.output.split(".")
name = ".".join(name)
x = mx.concatenate(decoded, axis=0)
x = (x * 255).astype(mx.uint8)
for i in range(len(x)):
im = Image.fromarray(np.array(x[i]))
im.save(".".join([name, str(i), suffix]))
else:
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB")
print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB")
print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")