diff --git a/flux/README.md b/flux/README.md new file mode 100644 index 00000000..33bebfd6 --- /dev/null +++ b/flux/README.md @@ -0,0 +1,185 @@ +FLUX +==== + +FLUX implementation in MLX. The implementation is ported directly from +[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux) +and the model weights are downloaded directly from the Hugging Face Hub. + +The goal of this example is to be clean, educational and to allow for +experimentation with finetuning FLUX models as well as adding extra +functionality such as in-/outpainting, guidance with custom losses etc. + +![MLX image](static/generated-mlx.png) +*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of +tron emanating futuristic technology with the word "MLX" in the center with +capital red letters.'* + +Installation +------------ + +The dependencies are minimal, namely: + +- `huggingface-hub` to download the checkpoints. +- `regex` for the tokenization +- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script +- `sentencepiece` for the T5 tokenizer + +You can install all of the above with the `requirements.txt` as follows: + + pip install -r requirements.txt + +Inference +--------- + +Inference in this example is similar to the stable diffusion example. The +classes to get you started are `FluxPipeline` from the `flux` module. + +```python +import mlx.core as mx +from flux import FluxPipeline + +# This will download all the weights from HF hub +flux = FluxPipeline("flux-schnell") + +# Make a generator that returns the latent variables from the reverse diffusion +# process +latent_generator = flux.generate_latents( + "A photo of an astronaut riding a horse on Mars", + num_steps=4, + latent_size=(32, 64), # 256x512 image +) + +# The first return value of the generator contains the conditioning and the +# random noise at the beginning of the diffusion process. +conditioning = next(latent_generator) +( + x_T, # The initial noise + x_positions, # The integer positions used for image positional encoding + t5_conditioning, # The T5 features from the text prompt + t5_positions, # Integer positions for text (normally all 0s) + clip_conditioning, # The clip text features from the text prompt +) = conditioning + +# Returning the conditioning as the first output from the generator allows us +# to unload T5 and clip before running the diffusion transformer. +mx.eval(conditioning) + +# Evaluate each diffusion step +for x_t in latent_generator: + mx.eval(x_t) + +# Note that we need to pass the latent size because it is collapsed and +# patchified in x_t and we need to unwrap it. +img = flux.decode(x_t, latent_size=(32, 64)) +``` + +The above are essentially the implementation of the `txt2image.py` script +except for some additional logic to quantize and/or load trained adapters. One +can use the script as follows: + +```shell +python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.' +``` + +### Experimental Options + +FLUX pads the prompt to a specific size of 512 tokens for the dev model and +256 for the schnell model. Not applying padding results in faster generation +but it is not clear how it may affect the generated images. To enable that +option in this example pass `--no-t5-padding` to the `txt2image.py` script or +instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`. 
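+
+For instance, a minimal invocation with prompt padding disabled (reusing the
+flags from the earlier `txt2image.py` example) could look as follows:
+
+```shell
+python txt2image.py --no-t5-padding --n-images 1 --image-size 256x512 \
+    'A photo of an astronaut riding a horse on Mars.'
+```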
+
+Finetuning
+----------
+
+The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell,
+although your mileage may vary) on a provided image dataset. The dataset
+folder must have an `index.json` file with the following format:
+
+```json
+{
+    "data": [
+        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
+        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
+        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
+        ...
+    ]
+}
+```
+
+The training script by default trains for 600 iterations with a batch size of
+1, gradient accumulation of 4, and a LoRA rank of 8. Run `python dreambooth.py
+--help` for the list of hyperparameters you can tune.
+
+> [!Note]
+> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
+> should reduce this number significantly.
+
+### Training Example
+
+This is a step-by-step finetuning example. We will be using the data from
+[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
+In particular, we will use `dog6`, which is a popular example for showcasing
+dreambooth.[^1]
+
+The training images are the following 5 images[^2]:
+
+![dog6](static/dog6.png)
+
+We start by making the following `index.json` file and placing it in the same
+folder as the images.
+
+```json
+{
+    "data": [
+        {"image": "00.jpg", "text": "A photo of sks dog"},
+        {"image": "01.jpg", "text": "A photo of sks dog"},
+        {"image": "02.jpg", "text": "A photo of sks dog"},
+        {"image": "03.jpg", "text": "A photo of sks dog"},
+        {"image": "04.jpg", "text": "A photo of sks dog"}
+    ]
+}
+```
+
+Subsequently, we finetune FLUX using the following command:
+
+```shell
+python dreambooth.py \
+    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
+    --lora-rank 4 --grad-accumulate 8 \
+    path/to/dreambooth/dataset/dog6
+```
+
+The training requires approximately 50GB of RAM, and on an M2 Ultra it takes a
+bit more than 1 hour.
+
+### Using the Adapter
+
+The adapters are saved in `mlx_output` and can be used directly by the
+`txt2image.py` script. For instance,
+
+```shell
+python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
+    --adapter mlx_output/0001200_adapters.safetensors \
+    --fuse-adapter \
+    --no-t5-padding \
+    'A photo of an sks dog lying on the sand at a beach in Greece'
+```
+
+generates an image that looks like the following,
+
+![dog image](static/dog-r4-g8-1200.png)
+
+and of course we can pass `--image-size 512x1024` to get larger images with
+different aspect ratios,
+
+![wide dog image](static/dog-r4-g8-1200-512x1024.png)
+
+The arguments relevant to the adapters are `--adapter` and `--fuse-adapter`.
+The first defines the path to the adapter to apply to the model, and the
+second fuses the adapter back into the model for slightly faster generation.
+
+[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
+[^2]: The images are from Unsplash by [@alvannee](https://unsplash.com/@alvannee).
diff --git a/flux/dreambooth.py b/flux/dreambooth.py
new file mode 100644
index 00000000..4a4dbb08
--- /dev/null
+++ b/flux/dreambooth.py
@@ -0,0 +1,378 @@
+# Copyright © 2024 Apple Inc.
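+#
+# Dreambooth-style LoRA finetuning for FLUX: the script pre-encodes the
+# training images into VAE latents and the prompts into T5/CLIP features,
+# trains LoRA adapters on the flow transformer (with optional gradient
+# accumulation), and periodically writes progress images and adapter
+# checkpoints to the output directory.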
+ +import argparse +import json +import time +from functools import partial +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +from mlx.nn.utils import average_gradients +from mlx.utils import tree_flatten, tree_map, tree_reduce +from PIL import Image +from tqdm import tqdm + +from flux import FluxPipeline + + +class FinetuningDataset: + def __init__(self, flux, args): + self.args = args + self.flux = flux + self.dataset_base = Path(args.dataset) + dataset_index = self.dataset_base / "index.json" + if not dataset_index.exists(): + raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset") + with open(dataset_index, "r") as f: + self.index = json.load(f) + + self.latents = [] + self.t5_features = [] + self.clip_features = [] + + def _random_crop_resize(self, img): + resolution = self.args.resolution + width, height = img.size + + a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist() + + # Random crop the input image between 0.8 to 1.0 of its original dimensions + crop_size = ( + max((0.8 + 0.2 * a) * width, resolution[0]), + max((0.8 + 0.2 * a) * height, resolution[1]), + ) + pan = (width - crop_size[0], height - crop_size[1]) + img = img.crop( + ( + pan[0] * b, + pan[1] * c, + crop_size[0] + pan[0] * b, + crop_size[1] + pan[1] * c, + ) + ) + + # Fit the largest rectangle with the ratio of resolution in the image + # rectangle. + width, height = crop_size + ratio = resolution[0] / resolution[1] + r1 = (height * ratio, height) + r2 = (width, width / ratio) + r = r1 if r1[0] <= width else r2 + img = img.crop( + ( + (width - r[0]) / 2, + (height - r[1]) / 2, + (width + r[0]) / 2, + (height + r[1]) / 2, + ) + ) + + # Finally resize the image to resolution + img = img.resize(resolution, Image.LANCZOS) + + return mx.array(np.array(img)) + + def encode_images(self): + """Encode the images in the latent space to prepare for training.""" + self.flux.ae.eval() + for sample in tqdm(self.index["data"]): + input_img = Image.open(self.dataset_base / sample["image"]) + for i in range(self.args.num_augmentations): + img = self._random_crop_resize(input_img) + img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1 + x_0 = self.flux.ae.encode(img[None]) + x_0 = x_0.astype(self.flux.dtype) + mx.eval(x_0) + self.latents.append(x_0) + + def encode_prompts(self): + """Pre-encode the prompts so that we don't recompute them during + training (doesn't allow finetuning the text encoders).""" + for sample in tqdm(self.index["data"]): + t5_tok, clip_tok = self.flux.tokenize([sample["text"]]) + t5_feat = self.flux.t5(t5_tok) + clip_feat = self.flux.clip(clip_tok).pooled_output + mx.eval(t5_feat, clip_feat) + self.t5_features.append(t5_feat) + self.clip_features.append(clip_feat) + + def iterate(self, batch_size): + xs = mx.concatenate(self.latents) + t5 = mx.concatenate(self.t5_features) + clip = mx.concatenate(self.clip_features) + mx.eval(xs, t5, clip) + n_aug = self.args.num_augmentations + while True: + x_indices = mx.random.permutation(len(self.latents)) + c_indices = x_indices // n_aug + for i in range(0, len(self.latents), batch_size): + x_i = x_indices[i : i + batch_size] + c_i = c_indices[i : i + batch_size] + yield xs[x_i], t5[c_i], clip[c_i] + + +def generate_progress_images(iteration, flux, args): + """Generate images to monitor the progress of the finetuning.""" + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_file = out_dir / f"{iteration:07d}_progress.png" + 
print(f"Generating {str(out_file)}", flush=True) + + # Generate some images and arrange them in a grid + n_rows = 2 + n_images = 4 + x = flux.generate_images( + args.progress_prompt, + n_images, + args.progress_steps, + ) + x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) + B, H, W, C = x.shape + x = x.reshape(n_rows, B // n_rows, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(n_rows * H, B // n_rows * W, C) + x = mx.pad(x, [(4, 4), (4, 4), (0, 0)]) + x = (x * 255).astype(mx.uint8) + + # Save them to disc + im = Image.fromarray(np.array(x)) + im.save(out_file) + + +def save_adapters(iteration, flux, args): + out_dir = Path(args.output_dir) + out_dir.mkdir(parents=True, exist_ok=True) + out_file = out_dir / f"{iteration:07d}_adapters.safetensors" + print(f"Saving {str(out_file)}") + + mx.save_safetensors( + str(out_file), + dict(tree_flatten(flux.flow.trainable_parameters())), + metadata={ + "lora_rank": str(args.lora_rank), + "lora_blocks": str(args.lora_blocks), + }, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Finetune Flux to generate images with a specific subject" + ) + + parser.add_argument( + "--model", + default="dev", + choices=[ + "dev", + "schnell", + ], + help="Which flux model to train", + ) + parser.add_argument( + "--guidance", type=float, default=4.0, help="The guidance factor to use." + ) + parser.add_argument( + "--iterations", + type=int, + default=600, + help="How many iterations to train for", + ) + parser.add_argument( + "--batch-size", + type=int, + default=1, + help="The batch size to use when training the stable diffusion model", + ) + parser.add_argument( + "--resolution", + type=lambda x: tuple(map(int, x.split("x"))), + default=(512, 512), + help="The resolution of the training images", + ) + parser.add_argument( + "--num-augmentations", + type=int, + default=5, + help="Augment the images by random cropping and panning", + ) + parser.add_argument( + "--progress-prompt", + required=True, + help="Use this prompt when generating images for evaluation", + ) + parser.add_argument( + "--progress-steps", + type=int, + default=50, + help="Use this many steps when generating images for evaluation", + ) + parser.add_argument( + "--progress-every", + type=int, + default=50, + help="Generate images every PROGRESS_EVERY steps", + ) + parser.add_argument( + "--checkpoint-every", + type=int, + default=50, + help="Save the model every CHECKPOINT_EVERY steps", + ) + parser.add_argument( + "--lora-blocks", + type=int, + default=-1, + help="Train the last LORA_BLOCKS transformer blocks", + ) + parser.add_argument( + "--lora-rank", type=int, default=8, help="LoRA rank for finetuning" + ) + parser.add_argument( + "--warmup-steps", type=int, default=100, help="Learning rate warmup" + ) + parser.add_argument( + "--learning-rate", type=float, default="1e-4", help="Learning rate for training" + ) + parser.add_argument( + "--grad-accumulate", + type=int, + default=4, + help="Accumulate gradients for that many iterations before applying them", + ) + parser.add_argument( + "--output-dir", default="mlx_output", help="Folder to save the checkpoints in" + ) + + parser.add_argument("dataset") + + args = parser.parse_args() + + # Load the model and set it up for LoRA training. We use the same random + # state when creating the LoRA layers so all workers will have the same + # initial weights. 
+ mx.random.seed(0x0F0F0F0F) + flux = FluxPipeline("flux-" + args.model) + flux.flow.freeze() + flux.linear_to_lora_layers(args.lora_rank, args.lora_blocks) + + # Reset the seed to a different seed per worker if we are in distributed + # mode so that each worker is working on different data, diffusion step and + # random noise. + mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank()) + + # Report how many parameters we are training + trainable_params = tree_reduce( + lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0 + ) + print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True) + + # Set up the optimizer and training steps. The steps are a bit verbose to + # support gradient accumulation together with compilation. + warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps) + cosine = optim.cosine_decay( + args.learning_rate, args.iterations // args.grad_accumulate + ) + lr_schedule = optim.join_schedules([warmup, cosine], [args.warmup_steps]) + optimizer = optim.Adam(learning_rate=lr_schedule) + state = [flux.flow.state, optimizer.state, mx.random.state] + + @partial(mx.compile, inputs=state, outputs=state) + def single_step(x, t5_feat, clip_feat, guidance): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + grads = average_gradients(grads) + optimizer.update(flux.flow, grads) + + return loss + + @partial(mx.compile, inputs=state, outputs=state) + def compute_loss_and_grads(x, t5_feat, clip_feat, guidance): + return nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + + @partial(mx.compile, inputs=state, outputs=state) + def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + grads = tree_map(lambda a, b: a + b, prev_grads, grads) + return loss, grads + + @partial(mx.compile, inputs=state, outputs=state) + def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads): + loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)( + x, t5_feat, clip_feat, guidance + ) + grads = tree_map( + lambda a, b: (a + b) / args.grad_accumulate, + prev_grads, + grads, + ) + grads = average_gradients(grads) + optimizer.update(flux.flow, grads) + + return loss + + # We simply route to the appropriate step based on whether we have + # gradients from a previous step and whether we should be performing an + # update or simply computing and accumulating gradients in this step. 
+ def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step): + if prev_grads is None: + if perform_step: + return single_step(x, t5_feat, clip_feat, guidance), None + else: + return compute_loss_and_grads(x, t5_feat, clip_feat, guidance) + else: + if perform_step: + return ( + grad_accumulate_and_step( + x, t5_feat, clip_feat, guidance, prev_grads + ), + None, + ) + else: + return compute_loss_and_accumulate_grads( + x, t5_feat, clip_feat, guidance, prev_grads + ) + + print("Create the training dataset.", flush=True) + dataset = FinetuningDataset(flux, args) + dataset.encode_images() + dataset.encode_prompts() + guidance = mx.full((args.batch_size,), args.guidance, dtype=flux.dtype) + + # An initial generation to compare + generate_progress_images(0, flux, args) + + grads = None + losses = [] + tic = time.time() + for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)): + loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0) + mx.eval(loss, grads, state) + losses.append(loss.item()) + + if (i + 1) % 10 == 0: + toc = time.time() + peak_mem = mx.metal.get_peak_memory() / 1024**3 + print( + f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} " + f"It/s: {10 / (toc - tic):.3f} " + f"Peak mem: {peak_mem:.3f} GB", + flush=True, + ) + + if (i + 1) % args.progress_every == 0: + generate_progress_images(i + 1, flux, args) + + if (i + 1) % args.checkpoint_every == 0: + save_adapters(i + 1, flux, args) + + if (i + 1) % 10 == 0: + losses = [] + tic = time.time() diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py new file mode 100644 index 00000000..8d39d605 --- /dev/null +++ b/flux/flux/__init__.py @@ -0,0 +1,248 @@ +# Copyright © 2024 Apple Inc. + +import math +import time +from typing import Tuple + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_unflatten +from tqdm import tqdm + +from .lora import LoRALinear +from .sampler import FluxSampler +from .utils import ( + load_ae, + load_clip, + load_clip_tokenizer, + load_flow_model, + load_t5, + load_t5_tokenizer, +) + + +class FluxPipeline: + def __init__(self, name: str, t5_padding: bool = True): + self.dtype = mx.bfloat16 + self.name = name + self.t5_padding = t5_padding + + self.ae = load_ae(name) + self.flow = load_flow_model(name) + self.clip = load_clip(name) + self.clip_tokenizer = load_clip_tokenizer(name) + self.t5 = load_t5(name) + self.t5_tokenizer = load_t5_tokenizer(name) + self.sampler = FluxSampler(name) + + def ensure_models_are_loaded(self): + mx.eval( + self.ae.parameters(), + self.flow.parameters(), + self.clip.parameters(), + self.t5.parameters(), + ) + + def reload_text_encoders(self): + self.t5 = load_t5(self.name) + self.clip = load_clip(self.name) + + def tokenize(self, text): + t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding) + clip_tokens = self.clip_tokenizer.encode(text) + return t5_tokens, clip_tokens + + def _prepare_latent_images(self, x): + b, h, w, c = x.shape + + # Pack the latent image to 2x2 patches + x = x.reshape(b, h // 2, 2, w // 2, 2, c) + x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4) + + # Create positions ids used to positionally encode each patch. Due to + # the way RoPE works, this results in an interesting positional + # encoding where parts of the feature are holding different positional + # information. Namely, the first part holds information independent of + # the spatial position (hence 0s), the 2nd part holds vertical spatial + # information and the last one horizontal. 
+ i = mx.zeros((h // 2, w // 2), dtype=mx.int32) + j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij") + x_ids = mx.stack([i, j, k], axis=-1) + x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0) + + return x, x_ids + + def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens): + # Prepare the text features + txt = self.t5(t5_tokens) + if len(txt) == 1 and n_images > 1: + txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:])) + txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32) + + # Prepare the clip text features + vec = self.clip(clip_tokens).pooled_output + if len(vec) == 1 and n_images > 1: + vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:])) + + return txt, txt_ids, vec + + def _denoising_loop( + self, + x_t, + x_ids, + txt, + txt_ids, + vec, + num_steps: int = 35, + guidance: float = 4.0, + start: float = 1, + stop: float = 0, + ): + B = len(x_t) + + def scalar(x): + return mx.full((B,), x, dtype=self.dtype) + + guidance = scalar(guidance) + timesteps = self.sampler.timesteps( + num_steps, + x_t.shape[1], + start=start, + stop=stop, + ) + for i in range(num_steps): + t = timesteps[i] + t_prev = timesteps[i + 1] + + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=scalar(t), + guidance=guidance, + ) + x_t = self.sampler.step(pred, x_t, t, t_prev) + + yield x_t + + def generate_latents( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + ): + # Set the PRNG state + if seed is not None: + mx.random.seed(seed) + + # Create the latent variables + x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype) + x_T, x_ids = self._prepare_latent_images(x_T) + + # Get the conditioning + t5_tokens, clip_tokens = self.tokenize(text) + txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens) + + # Yield the conditioning for controlled evaluation by the caller + yield (x_T, x_ids, txt, txt_ids, vec) + + # Yield the latent sequences from the denoising loop + yield from self._denoising_loop( + x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance + ) + + def decode(self, x, latent_size: Tuple[int, int] = (64, 64)): + h, w = latent_size + x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2) + x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1) + x = self.ae.decode(x) + return mx.clip(x + 1, 0, 2) * 0.5 + + def generate_images( + self, + text: str, + n_images: int = 1, + num_steps: int = 35, + guidance: float = 4.0, + latent_size: Tuple[int, int] = (64, 64), + seed=None, + reload_text_encoders: bool = True, + progress: bool = True, + ): + latents = self.generate_latents( + text, n_images, num_steps, guidance, latent_size, seed + ) + mx.eval(next(latents)) + + if reload_text_encoders: + self.reload_text_encoders() + + for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True): + mx.eval(x_t) + + images = [] + for i in tqdm(range(len(x_t)), disable=not progress): + images.append(self.decode(x_t[i : i + 1])) + mx.eval(images[-1]) + images = mx.concatenate(images, axis=0) + mx.eval(images) + + return images + + def training_loss( + self, + x_0: mx.array, + t5_features: mx.array, + clip_features: mx.array, + guidance: mx.array, + ): + # Get the text conditioning + txt = t5_features + txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32) + vec = clip_features + + # Prepare the latent input + x_0, x_ids = 
self._prepare_latent_images(x_0) + + # Forward process + t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype) + eps = mx.random.normal(x_0.shape, dtype=self.dtype) + x_t = self.sampler.add_noise(x_0, t, noise=eps) + x_t = mx.stop_gradient(x_t) + + # Do the denoising + pred = self.flow( + img=x_t, + img_ids=x_ids, + txt=txt, + txt_ids=txt_ids, + y=vec, + timesteps=t, + guidance=guidance, + ) + + return (pred + x_0 - eps).square().mean() + + def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1): + """Swap the linear layers in the transformer blocks with LoRA layers.""" + all_blocks = self.flow.double_blocks + self.flow.single_blocks + all_blocks.reverse() + num_blocks = num_blocks if num_blocks > 0 else len(all_blocks) + for i, block in zip(range(num_blocks), all_blocks): + loras = [] + for name, module in block.named_modules(): + if isinstance(module, nn.Linear): + loras.append((name, LoRALinear.from_base(module, r=rank))) + block.update_modules(tree_unflatten(loras)) + + def fuse_lora_layers(self): + fused_layers = [] + for name, module in self.flow.named_modules(): + if isinstance(module, LoRALinear): + fused_layers.append((name, module.fuse())) + self.flow.update_modules(tree_unflatten(fused_layers)) diff --git a/flux/flux/autoencoder.py b/flux/flux/autoencoder.py new file mode 100644 index 00000000..6332bb57 --- /dev/null +++ b/flux/flux/autoencoder.py @@ -0,0 +1,357 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import List + +import mlx.core as mx +import mlx.nn as nn +from mlx.nn.layers.upsample import upsample_nearest + + +@dataclass +class AutoEncoderParams: + resolution: int + in_channels: int + ch: int + out_ch: int + ch_mult: List[int] + num_res_blocks: int + z_channels: int + scale_factor: float + shift_factor: float + + +class AttnBlock(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.in_channels = in_channels + + self.norm = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.q = nn.Linear(in_channels, in_channels) + self.k = nn.Linear(in_channels, in_channels) + self.v = nn.Linear(in_channels, in_channels) + self.proj_out = nn.Linear(in_channels, in_channels) + + def __call__(self, x: mx.array) -> mx.array: + B, H, W, C = x.shape + + y = x.reshape(B, 1, -1, C) + y = self.norm(y) + q = self.q(y) + k = self.k(y) + v = self.v(y) + y = mx.fast.scaled_dot_product_attention(q, k, v, scale=C ** (-0.5)) + y = self.proj_out(y) + + return x + y.reshape(B, H, W, C) + + +class ResnetBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + + self.norm1 = nn.GroupNorm( + num_groups=32, + dims=in_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv1 = nn.Conv2d( + in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + self.norm2 = nn.GroupNorm( + num_groups=32, + dims=out_channels, + eps=1e-6, + affine=True, + pytorch_compatible=True, + ) + self.conv2 = nn.Conv2d( + out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + if self.in_channels != self.out_channels: + self.nin_shortcut = nn.Linear(in_channels, out_channels) + + def __call__(self, x): + h = x + h = self.norm1(h) + h = nn.silu(h) + h = self.conv1(h) + + h = self.norm2(h) + h = nn.silu(h) + h = self.conv2(h) + + if self.in_channels != 
self.out_channels: + x = self.nin_shortcut(x) + + return x + h + + +class Downsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def __call__(self, x: mx.array): + x = mx.pad(x, [(0, 0), (0, 1), (0, 1), (0, 0)]) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels: int): + super().__init__() + self.conv = nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + x = upsample_nearest(x, (2, 2)) + x = self.conv(x) + return x + + +class Encoder(nn.Module): + def __init__( + self, + resolution: int, + in_channels: int, + ch: int, + ch_mult: list[int], + num_res_blocks: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + # downsampling + self.conv_in = nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.in_ch_mult = in_ch_mult + self.down = [] + block_in = self.ch + for i_level in range(self.num_resolutions): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + down = {} + down["block"] = block + down["attn"] = attn + if i_level != self.num_resolutions - 1: + down["downsample"] = Downsample(block_in) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + self.conv_out = nn.Conv2d( + block_in, 2 * z_channels, kernel_size=3, stride=1, padding=1 + ) + + def __call__(self, x: mx.array): + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level]["block"][i_block](hs[-1]) + + # TODO: Remove the attn + if len(self.down[i_level]["attn"]) > 0: + h = self.down[i_level]["attn"][i_block](h) + + hs.append(h) + + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level]["downsample"](hs[-1])) + + # middle + h = hs[-1] + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class Decoder(nn.Module): + def __init__( + self, + ch: int, + out_ch: int, + ch_mult: list[int], + num_res_blocks: int, + in_channels: int, + resolution: int, + z_channels: int, + ): + super().__init__() + self.ch = ch + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + self.ffactor = 2 ** (self.num_resolutions - 1) + + # compute in_ch_mult, block_in and curr_res at lowest res + block_in = ch * ch_mult[self.num_resolutions - 1] + curr_res = resolution // 2 ** (self.num_resolutions - 1) + self.z_shape = (1, z_channels, curr_res, curr_res) + + # z to block_in + self.conv_in = nn.Conv2d( + 
z_channels, block_in, kernel_size=3, stride=1, padding=1 + ) + + # middle + self.mid = {} + self.mid["block_1"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + self.mid["attn_1"] = AttnBlock(block_in) + self.mid["block_2"] = ResnetBlock(in_channels=block_in, out_channels=block_in) + + # upsampling + self.up = [] + for i_level in reversed(range(self.num_resolutions)): + block = [] + attn = [] # TODO: Remove the attn, nobody appends anything to it + + block_out = ch * ch_mult[i_level] + for _ in range(self.num_res_blocks + 1): + block.append(ResnetBlock(in_channels=block_in, out_channels=block_out)) + block_in = block_out + up = {} + up["block"] = block + up["attn"] = attn + if i_level != 0: + up["upsample"] = Upsample(block_in) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = nn.GroupNorm( + num_groups=32, dims=block_in, eps=1e-6, affine=True, pytorch_compatible=True + ) + self.conv_out = nn.Conv2d(block_in, out_ch, kernel_size=3, stride=1, padding=1) + + def __call__(self, z: mx.array): + # z to block_in + h = self.conv_in(z) + + # middle + h = self.mid["block_1"](h) + h = self.mid["attn_1"](h) + h = self.mid["block_2"](h) + + # upsampling + for i_level in reversed(range(self.num_resolutions)): + for i_block in range(self.num_res_blocks + 1): + h = self.up[i_level]["block"][i_block](h) + + # TODO: Remove the attn + if len(self.up[i_level]["attn"]) > 0: + h = self.up[i_level]["attn"][i_block](h) + + if i_level != 0: + h = self.up[i_level]["upsample"](h) + + # end + h = self.norm_out(h) + h = nn.silu(h) + h = self.conv_out(h) + + return h + + +class DiagonalGaussian(nn.Module): + def __call__(self, z: mx.array): + mean, logvar = mx.split(z, 2, axis=-1) + if self.training: + std = mx.exp(0.5 * logvar) + eps = mx.random.normal(shape=z.shape, dtype=z.dtype) + return mean + std * eps + else: + return mean + + +class AutoEncoder(nn.Module): + def __init__(self, params: AutoEncoderParams): + super().__init__() + self.encoder = Encoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.decoder = Decoder( + resolution=params.resolution, + in_channels=params.in_channels, + ch=params.ch, + out_ch=params.out_ch, + ch_mult=params.ch_mult, + num_res_blocks=params.num_res_blocks, + z_channels=params.z_channels, + ) + self.reg = DiagonalGaussian() + + self.scale_factor = params.scale_factor + self.shift_factor = params.shift_factor + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + if w.ndim == 4: + w = w.transpose(0, 2, 3, 1) + w = w.reshape(-1).reshape(w.shape) + if w.shape[1:3] == (1, 1): + w = w.squeeze((1, 2)) + new_weights[k] = w + return new_weights + + def encode(self, x: mx.array): + z = self.reg(self.encoder(x)) + z = self.scale_factor * (z - self.shift_factor) + return z + + def decode(self, z: mx.array): + z = z / self.scale_factor + self.shift_factor + return self.decoder(z) + + def __call__(self, x: mx.array): + return self.decode(self.encode(x)) diff --git a/flux/flux/clip.py b/flux/flux/clip.py new file mode 100644 index 00000000..d5a30dbf --- /dev/null +++ b/flux/flux/clip.py @@ -0,0 +1,154 @@ +# Copyright © 2024 Apple Inc. 
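+#
+# CLIP text encoder. FLUX uses the pooled output (the hidden state at the EOS
+# token position) as the global text conditioning vector passed to the flow
+# transformer.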
+ +from dataclasses import dataclass +from typing import List, Optional + +import mlx.core as mx +import mlx.nn as nn + +_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu} + + +@dataclass +class CLIPTextModelConfig: + num_layers: int = 23 + model_dims: int = 1024 + num_heads: int = 16 + max_length: int = 77 + vocab_size: int = 49408 + hidden_act: str = "quick_gelu" + + @classmethod + def from_dict(cls, config): + return cls( + num_layers=config["num_hidden_layers"], + model_dims=config["hidden_size"], + num_heads=config["num_attention_heads"], + max_length=config["max_position_embeddings"], + vocab_size=config["vocab_size"], + hidden_act=config["hidden_act"], + ) + + +@dataclass +class CLIPOutput: + # The last_hidden_state indexed at the EOS token and possibly projected if + # the model has a projection layer + pooled_output: Optional[mx.array] = None + + # The full sequence output of the transformer after the final layernorm + last_hidden_state: Optional[mx.array] = None + + # A list of hidden states corresponding to the outputs of the transformer layers + hidden_states: Optional[List[mx.array]] = None + + +class CLIPEncoderLayer(nn.Module): + """The transformer encoder layer from CLIP.""" + + def __init__(self, model_dims: int, num_heads: int, activation: str): + super().__init__() + + self.layer_norm1 = nn.LayerNorm(model_dims) + self.layer_norm2 = nn.LayerNorm(model_dims) + + self.attention = nn.MultiHeadAttention(model_dims, num_heads, bias=True) + + self.linear1 = nn.Linear(model_dims, 4 * model_dims) + self.linear2 = nn.Linear(4 * model_dims, model_dims) + + self.act = _ACTIVATIONS[activation] + + def __call__(self, x, attn_mask=None): + y = self.layer_norm1(x) + y = self.attention(y, y, y, attn_mask) + x = y + x + + y = self.layer_norm2(x) + y = self.linear1(y) + y = self.act(y) + y = self.linear2(y) + x = y + x + + return x + + +class CLIPTextModel(nn.Module): + """Implements the text encoder transformer from CLIP.""" + + def __init__(self, config: CLIPTextModelConfig): + super().__init__() + + self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims) + self.position_embedding = nn.Embedding(config.max_length, config.model_dims) + self.layers = [ + CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act) + for i in range(config.num_layers) + ] + self.final_layer_norm = nn.LayerNorm(config.model_dims) + + def _get_mask(self, N, dtype): + indices = mx.arange(N) + mask = indices[:, None] < indices[None] + mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9) + return mask + + def sanitize(self, weights): + new_weights = {} + for key, w in weights.items(): + # Remove prefixes + if key.startswith("text_model."): + key = key[11:] + if key.startswith("embeddings."): + key = key[11:] + if key.startswith("encoder."): + key = key[8:] + + # Map attention layers + if "self_attn." in key: + key = key.replace("self_attn.", "attention.") + if "q_proj." in key: + key = key.replace("q_proj.", "query_proj.") + if "k_proj." in key: + key = key.replace("k_proj.", "key_proj.") + if "v_proj." 
in key: + key = key.replace("v_proj.", "value_proj.") + + # Map ffn layers + if "mlp.fc1" in key: + key = key.replace("mlp.fc1", "linear1") + if "mlp.fc2" in key: + key = key.replace("mlp.fc2", "linear2") + + new_weights[key] = w + + return new_weights + + def __call__(self, x): + # Extract some shapes + B, N = x.shape + eos_tokens = x.argmax(-1) + + # Compute the embeddings + x = self.token_embedding(x) + x = x + self.position_embedding.weight[:N] + + # Compute the features from the transformer + mask = self._get_mask(N, x.dtype) + hidden_states = [] + for l in self.layers: + x = l(x, mask) + hidden_states.append(x) + + # Apply the final layernorm and return + x = self.final_layer_norm(x) + last_hidden_state = x + + # Select the EOS token + pooled_output = x[mx.arange(len(x)), eos_tokens] + + return CLIPOutput( + pooled_output=pooled_output, + last_hidden_state=last_hidden_state, + hidden_states=hidden_states, + ) diff --git a/flux/flux/layers.py b/flux/flux/layers.py new file mode 100644 index 00000000..12397904 --- /dev/null +++ b/flux/flux/layers.py @@ -0,0 +1,302 @@ +# Copyright © 2024 Apple Inc. + +import math +from dataclasses import dataclass +from functools import partial +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + + +def _rope(pos: mx.array, dim: int, theta: float): + scale = mx.arange(0, dim, 2, dtype=mx.float32) / dim + omega = 1.0 / (theta**scale) + x = pos[..., None] * omega + cosx = mx.cos(x) + sinx = mx.sin(x) + pe = mx.stack([cosx, -sinx, sinx, cosx], axis=-1) + pe = pe.reshape(*pe.shape[:-1], 2, 2) + + return pe + + +@partial(mx.compile, shapeless=True) +def _ab_plus_cd(a, b, c, d): + return a * b + c * d + + +def _apply_rope(x, pe): + s = x.shape + x = x.reshape(*s[:-1], -1, 1, 2) + x = _ab_plus_cd(x[..., 0], pe[..., 0], x[..., 1], pe[..., 1]) + return x.reshape(s) + + +def _attention(q: mx.array, k: mx.array, v: mx.array, pe: mx.array): + B, H, L, D = q.shape + + q = _apply_rope(q, pe) + k = _apply_rope(k, pe) + x = mx.fast.scaled_dot_product_attention(q, k, v, scale=D ** (-0.5)) + + return x.transpose(0, 2, 1, 3).reshape(B, L, -1) + + +def timestep_embedding( + t: mx.array, dim: int, max_period: int = 10000, time_factor: float = 1000.0 +): + half = dim // 2 + freqs = mx.arange(0, half, dtype=mx.float32) / half + freqs = freqs * (-math.log(max_period)) + freqs = mx.exp(freqs) + + x = (time_factor * t)[:, None] * freqs[None] + x = mx.concatenate([mx.cos(x), mx.sin(x)], axis=-1) + + return x.astype(t.dtype) + + +class EmbedND(nn.Module): + def __init__(self, dim: int, theta: int, axes_dim: List[int]): + super().__init__() + + self.dim = dim + self.theta = theta + self.axes_dim = axes_dim + + def __call__(self, ids: mx.array): + n_axes = ids.shape[-1] + pe = mx.concatenate( + [_rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)], + axis=-3, + ) + + return pe[:, None] + + +class MLPEmbedder(nn.Module): + def __init__(self, in_dim: int, hidden_dim: int): + super().__init__() + self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True) + self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True) + + def __call__(self, x: mx.array) -> mx.array: + return self.out_layer(nn.silu(self.in_layer(x))) + + +class QKNorm(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.query_norm = nn.RMSNorm(dim) + self.key_norm = nn.RMSNorm(dim) + + def __call__(self, q: mx.array, k: mx.array) -> tuple[mx.array, mx.array]: + return self.query_norm(q), self.key_norm(k) + + +class SelfAttention(nn.Module): + def 
__init__(self, dim: int, num_heads: int = 8, qkv_bias: bool = False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.norm = QKNorm(head_dim) + self.proj = nn.Linear(dim, dim) + + def __call__(self, x: mx.array, pe: mx.array) -> mx.array: + H = self.num_heads + B, L, _ = x.shape + qkv = self.qkv(x) + q, k, v = mx.split(qkv, 3, axis=-1) + q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + q, k = self.norm(q, k) + x = _attention(q, k, v, pe) + x = self.proj(x) + return x + + +@dataclass +class ModulationOut: + shift: mx.array + scale: mx.array + gate: mx.array + + +class Modulation(nn.Module): + def __init__(self, dim: int, double: bool): + super().__init__() + self.is_double = double + self.multiplier = 6 if double else 3 + self.lin = nn.Linear(dim, self.multiplier * dim, bias=True) + + def __call__(self, x: mx.array) -> Tuple[ModulationOut, Optional[ModulationOut]]: + x = self.lin(nn.silu(x)) + xs = mx.split(x[:, None, :], self.multiplier, axis=-1) + + mod1 = ModulationOut(*xs[:3]) + mod2 = ModulationOut(*xs[3:]) if self.is_double else None + + return mod1, mod2 + + +class DoubleStreamBlock(nn.Module): + def __init__( + self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias: bool = False + ): + super().__init__() + + mlp_hidden_dim = int(hidden_size * mlp_ratio) + self.num_heads = num_heads + self.hidden_size = hidden_size + self.img_mod = Modulation(hidden_size, double=True) + self.img_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.img_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.img_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.img_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approx="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + self.txt_mod = Modulation(hidden_size, double=True) + self.txt_norm1 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.txt_attn = SelfAttention( + dim=hidden_size, num_heads=num_heads, qkv_bias=qkv_bias + ) + + self.txt_norm2 = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.txt_mlp = nn.Sequential( + nn.Linear(hidden_size, mlp_hidden_dim, bias=True), + nn.GELU(approx="tanh"), + nn.Linear(mlp_hidden_dim, hidden_size, bias=True), + ) + + def __call__( + self, img: mx.array, txt: mx.array, vec: mx.array, pe: mx.array + ) -> Tuple[mx.array, mx.array]: + B, L, _ = img.shape + _, S, _ = txt.shape + H = self.num_heads + + img_mod1, img_mod2 = self.img_mod(vec) + txt_mod1, txt_mod2 = self.txt_mod(vec) + + # prepare image for attention + img_modulated = self.img_norm1(img) + img_modulated = (1 + img_mod1.scale) * img_modulated + img_mod1.shift + img_qkv = self.img_attn.qkv(img_modulated) + img_q, img_k, img_v = mx.split(img_qkv, 3, axis=-1) + img_q = img_q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_k = img_k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_v = img_v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + img_q, img_k = self.img_attn.norm(img_q, img_k) + + # prepare txt for attention + txt_modulated = self.txt_norm1(txt) + txt_modulated = (1 + txt_mod1.scale) * txt_modulated + txt_mod1.shift + txt_qkv = self.txt_attn.qkv(txt_modulated) + txt_q, txt_k, txt_v = mx.split(txt_qkv, 3, axis=-1) + txt_q = txt_q.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_k = txt_k.reshape(B, S, H, -1).transpose(0, 
2, 1, 3) + txt_v = txt_v.reshape(B, S, H, -1).transpose(0, 2, 1, 3) + txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k) + + # run actual attention + q = mx.concatenate([txt_q, img_q], axis=2) + k = mx.concatenate([txt_k, img_k], axis=2) + v = mx.concatenate([txt_v, img_v], axis=2) + + attn = _attention(q, k, v, pe) + txt_attn, img_attn = mx.split(attn, [S], axis=1) + + # calculate the img bloks + img = img + img_mod1.gate * self.img_attn.proj(img_attn) + img = img + img_mod2.gate * self.img_mlp( + (1 + img_mod2.scale) * self.img_norm2(img) + img_mod2.shift + ) + + # calculate the txt bloks + txt = txt + txt_mod1.gate * self.txt_attn.proj(txt_attn) + txt = txt + txt_mod2.gate * self.txt_mlp( + (1 + txt_mod2.scale) * self.txt_norm2(txt) + txt_mod2.shift + ) + + return img, txt + + +class SingleStreamBlock(nn.Module): + def __init__( + self, + hidden_size: int, + num_heads: int, + mlp_ratio: float = 4.0, + qk_scale: Optional[float] = None, + ): + super().__init__() + self.hidden_dim = hidden_size + self.num_heads = num_heads + head_dim = hidden_size // num_heads + self.scale = qk_scale or head_dim**-0.5 + + self.mlp_hidden_dim = int(hidden_size * mlp_ratio) + # qkv and mlp_in + self.linear1 = nn.Linear(hidden_size, hidden_size * 3 + self.mlp_hidden_dim) + # proj and mlp_out + self.linear2 = nn.Linear(hidden_size + self.mlp_hidden_dim, hidden_size) + + self.norm = QKNorm(head_dim) + + self.hidden_size = hidden_size + self.pre_norm = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + + self.mlp_act = nn.GELU(approx="tanh") + self.modulation = Modulation(hidden_size, double=False) + + def __call__(self, x: mx.array, vec: mx.array, pe: mx.array): + B, L, _ = x.shape + H = self.num_heads + + mod, _ = self.modulation(vec) + x_mod = (1 + mod.scale) * self.pre_norm(x) + mod.shift + + q, k, v, mlp = mx.split( + self.linear1(x_mod), + [self.hidden_size, 2 * self.hidden_size, 3 * self.hidden_size], + axis=-1, + ) + q = q.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + k = k.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + v = v.reshape(B, L, H, -1).transpose(0, 2, 1, 3) + q, k = self.norm(q, k) + + # compute attention + y = _attention(q, k, v, pe) + + # compute activation in mlp stream, cat again and run second linear layer + y = self.linear2(mx.concatenate([y, self.mlp_act(mlp)], axis=2)) + return x + mod.gate * y + + +class LastLayer(nn.Module): + def __init__(self, hidden_size: int, patch_size: int, out_channels: int): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, affine=False, eps=1e-6) + self.linear = nn.Linear( + hidden_size, patch_size * patch_size * out_channels, bias=True + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def __call__(self, x: mx.array, vec: mx.array): + shift, scale = mx.split(self.adaLN_modulation(vec), 2, axis=1) + x = (1 + scale[:, None, :]) * self.norm_final(x) + shift[:, None, :] + x = self.linear(x) + return x diff --git a/flux/flux/lora.py b/flux/flux/lora.py new file mode 100644 index 00000000..b0c8ae56 --- /dev/null +++ b/flux/flux/lora.py @@ -0,0 +1,76 @@ +# Copyright © 2024 Apple Inc. 
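+#
+# Minimal LoRA linear layer: it wraps a base nn.Linear (kept frozen during
+# finetuning) and adds a trainable low-rank update,
+# scale * (dropout(x) @ lora_a @ lora_b). `fuse()` folds the update back into
+# a plain nn.Linear for faster inference.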
+ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class LoRALinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Linear, + r: int = 8, + dropout: float = 0.0, + scale: float = 1.0, + ): + output_dims, input_dims = linear.weight.shape + lora_lin = LoRALinear( + input_dims=input_dims, + output_dims=output_dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + dtype = weight.dtype + + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = self.scale * self.lora_b.T + lora_a = self.lora_a.T + fused_linear.weight = weight + (lora_b @ lora_a).astype(dtype) + if bias: + fused_linear.bias = linear.bias + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 1.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, r), + ) + self.lora_b = mx.zeros(shape=(r, output_dims)) + + def __call__(self, x): + y = self.linear(x) + z = (self.dropout(x) @ self.lora_a) @ self.lora_b + return y + (self.scale * z).astype(x.dtype) diff --git a/flux/flux/model.py b/flux/flux/model.py new file mode 100644 index 00000000..18ea70b0 --- /dev/null +++ b/flux/flux/model.py @@ -0,0 +1,134 @@ +# Copyright © 2024 Apple Inc. + +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +import mlx.nn as nn + +from .layers import ( + DoubleStreamBlock, + EmbedND, + LastLayer, + MLPEmbedder, + SingleStreamBlock, + timestep_embedding, +) + + +@dataclass +class FluxParams: + in_channels: int + vec_in_dim: int + context_in_dim: int + hidden_size: int + mlp_ratio: float + num_heads: int + depth: int + depth_single_blocks: int + axes_dim: list[int] + theta: int + qkv_bias: bool + guidance_embed: bool + + +class Flux(nn.Module): + def __init__(self, params: FluxParams): + super().__init__() + + self.params = params + self.in_channels = params.in_channels + self.out_channels = self.in_channels + if params.hidden_size % params.num_heads != 0: + raise ValueError( + f"Hidden size {params.hidden_size} must be divisible by num_heads {params.num_heads}" + ) + pe_dim = params.hidden_size // params.num_heads + if sum(params.axes_dim) != pe_dim: + raise ValueError( + f"Got {params.axes_dim} but expected positional dim {pe_dim}" + ) + self.hidden_size = params.hidden_size + self.num_heads = params.num_heads + self.pe_embedder = EmbedND( + dim=pe_dim, theta=params.theta, axes_dim=params.axes_dim + ) + self.img_in = nn.Linear(self.in_channels, self.hidden_size, bias=True) + self.time_in = MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + self.vector_in = MLPEmbedder(params.vec_in_dim, self.hidden_size) + self.guidance_in = ( + MLPEmbedder(in_dim=256, hidden_dim=self.hidden_size) + if params.guidance_embed + else nn.Identity() + ) + self.txt_in = nn.Linear(params.context_in_dim, self.hidden_size) + + self.double_blocks = [ + DoubleStreamBlock( + self.hidden_size, + self.num_heads, + mlp_ratio=params.mlp_ratio, + qkv_bias=params.qkv_bias, + ) + for _ in range(params.depth) + ] + + 
self.single_blocks = [ + SingleStreamBlock( + self.hidden_size, self.num_heads, mlp_ratio=params.mlp_ratio + ) + for _ in range(params.depth_single_blocks) + ] + + self.final_layer = LastLayer(self.hidden_size, 1, self.out_channels) + + def sanitize(self, weights): + new_weights = {} + for k, w in weights.items(): + if k.endswith(".scale"): + k = k[:-6] + ".weight" + for seq in ["img_mlp", "txt_mlp", "adaLN_modulation"]: + if f".{seq}." in k: + k = k.replace(f".{seq}.", f".{seq}.layers.") + break + new_weights[k] = w + return new_weights + + def __call__( + self, + img: mx.array, + img_ids: mx.array, + txt: mx.array, + txt_ids: mx.array, + timesteps: mx.array, + y: mx.array, + guidance: Optional[mx.array] = None, + ) -> mx.array: + if img.ndim != 3 or txt.ndim != 3: + raise ValueError("Input img and txt tensors must have 3 dimensions.") + + img = self.img_in(img) + vec = self.time_in(timestep_embedding(timesteps, 256)) + if self.params.guidance_embed: + if guidance is None: + raise ValueError( + "Didn't get guidance strength for guidance distilled model." + ) + vec = vec + self.guidance_in(timestep_embedding(guidance, 256)) + vec = vec + self.vector_in(y) + txt = self.txt_in(txt) + + ids = mx.concatenate([txt_ids, img_ids], axis=1) + pe = self.pe_embedder(ids).astype(img.dtype) + + for block in self.double_blocks: + img, txt = block(img=img, txt=txt, vec=vec, pe=pe) + + img = mx.concatenate([txt, img], axis=1) + for block in self.single_blocks: + img = block(img, vec=vec, pe=pe) + img = img[:, txt.shape[1] :, ...] + + img = self.final_layer(img, vec) + + return img diff --git a/flux/flux/sampler.py b/flux/flux/sampler.py new file mode 100644 index 00000000..3bff1ca2 --- /dev/null +++ b/flux/flux/sampler.py @@ -0,0 +1,56 @@ +# Copyright © 2024 Apple Inc. + +import math +from functools import lru_cache + +import mlx.core as mx + + +class FluxSampler: + def __init__(self, name: str, base_shift: float = 0.5, max_shift: float = 1.5): + self._base_shift = base_shift + self._max_shift = max_shift + self._schnell = "schnell" in name + + def _time_shift(self, x, t): + x1, x2 = 256, 4096 + t1, t2 = self._base_shift, self._max_shift + exp_mu = math.exp((x - x1) * (t2 - t1) / (x2 - x1) + t1) + t = exp_mu / (exp_mu + (1 / t - 1)) + return t + + @lru_cache + def timesteps( + self, num_steps, image_sequence_length, start: float = 1, stop: float = 0 + ): + t = mx.linspace(start, stop, num_steps + 1) + + if self._schnell: + t = self._time_shift(image_sequence_length, t) + + return t.tolist() + + def random_timesteps(self, B, L, dtype=mx.float32, key=None): + if self._schnell: + # TODO: Should we upweigh 1 and 0.75? + t = mx.random.randint(1, 5, shape=(B,), key=key) + t = t.astype(dtype) / 4 + else: + t = mx.random.uniform(shape=(B,), dtype=dtype, key=key) + t = self._time_shift(L, t) + + return t + + def sample_prior(self, shape, dtype=mx.float32, key=None): + return mx.random.normal(shape, dtype=dtype, key=key) + + def add_noise(self, x, t, noise=None, key=None): + noise = ( + noise + if noise is not None + else mx.random.normal(x.shape, dtype=x.dtype, key=key) + ) + return x * (1 - t) + t * noise + + def step(self, pred, x_t, t, t_prev): + return x_t + (t_prev - t) * pred diff --git a/flux/flux/t5.py b/flux/flux/t5.py new file mode 100644 index 00000000..cf0515cd --- /dev/null +++ b/flux/flux/t5.py @@ -0,0 +1,244 @@ +# Copyright © 2024 Apple Inc. 
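+#
+# Encoder-only T5 model. FLUX uses its per-token output features as the
+# sequence text conditioning for the flow transformer.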
+ +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple + +import mlx.core as mx +import mlx.nn as nn + +_SHARED_REPLACEMENT_PATTERNS = [ + (".block.", ".layers."), + (".k.", ".key_proj."), + (".o.", ".out_proj."), + (".q.", ".query_proj."), + (".v.", ".value_proj."), + ("shared.", "wte."), + ("lm_head.", "lm_head.linear."), + (".layer.0.layer_norm.", ".ln1."), + (".layer.1.layer_norm.", ".ln2."), + (".layer.2.layer_norm.", ".ln3."), + (".final_layer_norm.", ".ln."), + ( + "layers.0.layer.0.SelfAttention.relative_attention_bias.", + "relative_attention_bias.embeddings.", + ), +] + +_ENCODER_REPLACEMENT_PATTERNS = [ + (".layer.0.SelfAttention.", ".attention."), + (".layer.1.DenseReluDense.", ".dense."), +] + + +@dataclass +class T5Config: + vocab_size: int + num_layers: int + num_heads: int + relative_attention_num_buckets: int + d_kv: int + d_model: int + feed_forward_proj: str + tie_word_embeddings: bool + + d_ff: Optional[int] = None + num_decoder_layers: Optional[int] = None + relative_attention_max_distance: int = 128 + layer_norm_epsilon: float = 1e-6 + + @classmethod + def from_dict(cls, config): + return cls( + vocab_size=config["vocab_size"], + num_layers=config["num_layers"], + num_heads=config["num_heads"], + relative_attention_num_buckets=config["relative_attention_num_buckets"], + d_kv=config["d_kv"], + d_model=config["d_model"], + feed_forward_proj=config["feed_forward_proj"], + tie_word_embeddings=config["tie_word_embeddings"], + d_ff=config.get("d_ff", 4 * config["d_model"]), + num_decoder_layers=config.get("num_decoder_layers", config["num_layers"]), + relative_attention_max_distance=config.get( + "relative_attention_max_distance", 128 + ), + layer_norm_epsilon=config.get("layer_norm_epsilon", 1e-6), + ) + + +class RelativePositionBias(nn.Module): + def __init__(self, config: T5Config, bidirectional: bool): + self.bidirectional = bidirectional + self.num_buckets = config.relative_attention_num_buckets + self.max_distance = config.relative_attention_max_distance + self.n_heads = config.num_heads + self.embeddings = nn.Embedding(self.num_buckets, self.n_heads) + + @staticmethod + def _relative_position_bucket(rpos, bidirectional, num_buckets, max_distance): + num_buckets = num_buckets // 2 if bidirectional else num_buckets + max_exact = num_buckets // 2 + + abspos = rpos.abs() + is_small = abspos < max_exact + + scale = (num_buckets - max_exact) / math.log(max_distance / max_exact) + buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16) + buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1) + + buckets = mx.where(is_small, abspos, buckets_large) + if bidirectional: + buckets = buckets + (rpos > 0) * num_buckets + else: + buckets = buckets * (rpos < 0) + + return buckets + + def __call__(self, query_length: int, key_length: int, offset: int = 0): + """Compute binned relative position bias""" + context_position = mx.arange(offset, query_length)[:, None] + memory_position = mx.arange(key_length)[None, :] + + # shape (query_length, key_length) + relative_position = memory_position - context_position + relative_position_bucket = self._relative_position_bucket( + relative_position, + bidirectional=self.bidirectional, + num_buckets=self.num_buckets, + max_distance=self.max_distance, + ) + + # shape (query_length, key_length, num_heads) + values = self.embeddings(relative_position_bucket) + + # shape (num_heads, query_length, key_length) + return values.transpose(2, 0, 1) + + +class MultiHeadAttention(nn.Module): + 
def __init__(self, config: T5Config):
+        super().__init__()
+        inner_dim = config.d_kv * config.num_heads
+        self.num_heads = config.num_heads
+        self.query_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.key_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.value_proj = nn.Linear(config.d_model, inner_dim, bias=False)
+        self.out_proj = nn.Linear(inner_dim, config.d_model, bias=False)
+
+    def __call__(
+        self,
+        queries: mx.array,
+        keys: mx.array,
+        values: mx.array,
+        mask: Optional[mx.array],
+        cache: Optional[Tuple[mx.array, mx.array]] = None,
+    ) -> Tuple[mx.array, Tuple[mx.array, mx.array]]:
+        queries = self.query_proj(queries)
+        keys = self.key_proj(keys)
+        values = self.value_proj(values)
+
+        num_heads = self.num_heads
+        B, L, _ = queries.shape
+        _, S, _ = keys.shape
+        queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+        values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
+
+        if cache is not None:
+            # Keys and values are (B, num_heads, S, d_kv) here, so the cache is
+            # concatenated along the sequence axis.
+            key_cache, value_cache = cache
+            keys = mx.concatenate([key_cache, keys], axis=2)
+            values = mx.concatenate([value_cache, values], axis=2)
+
+        values_hat = mx.fast.scaled_dot_product_attention(
+            queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
+        )
+        values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
+
+        return self.out_proj(values_hat), (keys, values)
+
+
+class DenseActivation(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        mlp_dims = config.d_ff or config.d_model * 4
+        self.gated = config.feed_forward_proj.startswith("gated")
+        if self.gated:
+            self.wi_0 = nn.Linear(config.d_model, mlp_dims, bias=False)
+            self.wi_1 = nn.Linear(config.d_model, mlp_dims, bias=False)
+        else:
+            self.wi = nn.Linear(config.d_model, mlp_dims, bias=False)
+        self.wo = nn.Linear(mlp_dims, config.d_model, bias=False)
+        activation = config.feed_forward_proj.removeprefix("gated-")
+        if activation == "relu":
+            self.act = nn.relu
+        elif activation == "gelu":
+            self.act = nn.gelu
+        elif activation == "silu":
+            self.act = nn.silu
+        else:
+            raise ValueError(f"Unknown activation: {activation}")
+
+    def __call__(self, x):
+        if self.gated:
+            hidden_act = self.act(self.wi_0(x))
+            hidden_linear = self.wi_1(x)
+            x = hidden_act * hidden_linear
+        else:
+            x = self.act(self.wi(x))
+        return self.wo(x)
+
+
+class TransformerEncoderLayer(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.attention = MultiHeadAttention(config)
+        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.dense = DenseActivation(config)
+
+    def __call__(self, x, mask):
+        y = self.ln1(x)
+        y, _ = self.attention(y, y, y, mask=mask)
+        x = x + y
+
+        y = self.ln2(x)
+        y = self.dense(y)
+        return x + y
+
+
+class TransformerEncoder(nn.Module):
+    def __init__(self, config: T5Config):
+        super().__init__()
+        self.layers = [
+            TransformerEncoderLayer(config) for i in range(config.num_layers)
+        ]
+        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
+        self.relative_attention_bias = RelativePositionBias(config, bidirectional=True)
+
+    def __call__(self, x: mx.array):
+        pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+        pos_bias = pos_bias.astype(x.dtype)
+        for layer in self.layers:
+            x = layer(x, mask=pos_bias)
+        return self.ln(x)
+
+
+class T5Encoder(nn.Module):
+    def __init__(self, config: T5Config):
+        self.wte = nn.Embedding(config.vocab_size, config.d_model)
+        self.encoder = TransformerEncoder(config)
+
+    def sanitize(self, weights):
+        new_weights = {}
+        for k, w in weights.items():
+            for old, new in _SHARED_REPLACEMENT_PATTERNS:
+                k = k.replace(old, new)
+            if k.startswith("encoder."):
+                for old, new in _ENCODER_REPLACEMENT_PATTERNS:
+                    k = k.replace(old, new)
+            new_weights[k] = w
+        return new_weights
+
+    def __call__(self, inputs: mx.array):
+        return self.encoder(self.wte(inputs))
diff --git a/flux/flux/tokenizers.py b/flux/flux/tokenizers.py
new file mode 100644
index 00000000..796ef389
--- /dev/null
+++ b/flux/flux/tokenizers.py
@@ -0,0 +1,185 @@
+# Copyright © 2024 Apple Inc.
+
+import mlx.core as mx
+import regex
+from sentencepiece import SentencePieceProcessor
+
+
+class CLIPTokenizer:
+    """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
+
+    def __init__(self, bpe_ranks, vocab, max_length=77):
+        self.max_length = max_length
+        self.bpe_ranks = bpe_ranks
+        self.vocab = vocab
+        self.pat = regex.compile(
+            r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
+            regex.IGNORECASE,
+        )
+
+        self._cache = {self.bos: self.bos, self.eos: self.eos}
+
+    @property
+    def bos(self):
+        return "<|startoftext|>"
+
+    @property
+    def bos_token(self):
+        return self.vocab[self.bos]
+
+    @property
+    def eos(self):
+        return "<|endoftext|>"
+
+    @property
+    def eos_token(self):
+        return self.vocab[self.eos]
+
+    def bpe(self, text):
+        if text in self._cache:
+            return self._cache[text]
+
+        # The last unigram carries the end-of-word marker used by the CLIP vocab.
+        unigrams = list(text[:-1]) + [text[-1] + "</w>"]
+        unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        if not unique_bigrams:
+            return unigrams
+
+        # In every iteration try to merge the two most likely bigrams. If none
+        # was merged we are done.
+        #
+        # Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
+        while unique_bigrams:
+            bigram = min(
+                unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf"))
+            )
+            if bigram not in self.bpe_ranks:
+                break
+
+            new_unigrams = []
+            skip = False
+            for a, b in zip(unigrams, unigrams[1:]):
+                if skip:
+                    skip = False
+                    continue
+
+                if (a, b) == bigram:
+                    new_unigrams.append(a + b)
+                    skip = True
+
+                else:
+                    new_unigrams.append(a)
+
+            if not skip:
+                new_unigrams.append(b)
+
+            unigrams = new_unigrams
+            unique_bigrams = set(zip(unigrams, unigrams[1:]))
+
+        self._cache[text] = unigrams
+
+        return unigrams
+
+    def tokenize(self, text, prepend_bos=True, append_eos=True):
+        if isinstance(text, list):
+            return [self.tokenize(t, prepend_bos, append_eos) for t in text]
+
+        # Lower case cleanup and split according to self.pat. Hugging Face does
+        # a much more thorough job here but this should suffice for 95% of
+        # cases.
+ clean_text = regex.sub(r"\s+", " ", text.lower()) + tokens = regex.findall(self.pat, clean_text) + + # Split the tokens according to the byte-pair merge file + bpe_tokens = [ti for t in tokens for ti in self.bpe(t)] + + # Map to token ids and return + tokens = [self.vocab[t] for t in bpe_tokens] + if prepend_bos: + tokens = [self.bos_token] + tokens + if append_eos: + tokens.append(self.eos_token) + + if len(tokens) > self.max_length: + tokens = tokens[: self.max_length] + if append_eos: + tokens[-1] = self.eos_token + + return tokens + + def encode(self, text): + if not isinstance(text, list): + return self.encode([text]) + + tokens = self.tokenize(text) + length = max(len(t) for t in tokens) + for t in tokens: + t.extend([self.eos_token] * (length - len(t))) + + return mx.array(tokens) + + +class T5Tokenizer: + def __init__(self, model_file, max_length=512): + self._tokenizer = SentencePieceProcessor(model_file) + self.max_length = max_length + + @property + def pad(self): + try: + return self._tokenizer.id_to_piece(self.pad_token) + except IndexError: + return None + + @property + def pad_token(self): + return self._tokenizer.pad_id() + + @property + def bos(self): + try: + return self._tokenizer.id_to_piece(self.bos_token) + except IndexError: + return None + + @property + def bos_token(self): + return self._tokenizer.bos_id() + + @property + def eos(self): + try: + return self._tokenizer.id_to_piece(self.eos_token) + except IndexError: + return None + + @property + def eos_token(self): + return self._tokenizer.eos_id() + + def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True): + if isinstance(text, list): + return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text] + + tokens = self._tokenizer.encode(text) + + if prepend_bos and self.bos_token >= 0: + tokens = [self.bos_token] + tokens + if append_eos and self.eos_token >= 0: + tokens.append(self.eos_token) + if pad and len(tokens) < self.max_length and self.pad_token >= 0: + tokens += [self.pad_token] * (self.max_length - len(tokens)) + + return tokens + + def encode(self, text, pad=True): + if not isinstance(text, list): + return self.encode([text], pad=pad) + + pad_token = self.pad_token if self.pad_token >= 0 else 0 + tokens = self.tokenize(text, pad=pad) + length = max(len(t) for t in tokens) + for t in tokens: + t.extend([pad_token] * (length - len(t))) + + return mx.array(tokens) diff --git a/flux/flux/utils.py b/flux/flux/utils.py new file mode 100644 index 00000000..21db17d3 --- /dev/null +++ b/flux/flux/utils.py @@ -0,0 +1,209 @@ +# Copyright © 2024 Apple Inc. 
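+
+# Model specs for the released FLUX checkpoints together with helpers that
+# download and load the flow transformer, the autoencoder, the CLIP and T5
+# text encoders, and the matching tokenizers from the Hugging Face Hub.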
+ +import json +import os +from dataclasses import dataclass +from typing import Optional + +import mlx.core as mx +from huggingface_hub import hf_hub_download + +from .autoencoder import AutoEncoder, AutoEncoderParams +from .clip import CLIPTextModel, CLIPTextModelConfig +from .model import Flux, FluxParams +from .t5 import T5Config, T5Encoder +from .tokenizers import CLIPTokenizer, T5Tokenizer + + +@dataclass +class ModelSpec: + params: FluxParams + ae_params: AutoEncoderParams + ckpt_path: Optional[str] + ae_path: Optional[str] + repo_id: Optional[str] + repo_flow: Optional[str] + repo_ae: Optional[str] + + +configs = { + "flux-dev": ModelSpec( + repo_id="black-forest-labs/FLUX.1-dev", + repo_flow="flux1-dev.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_DEV"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=True, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), + "flux-schnell": ModelSpec( + repo_id="black-forest-labs/FLUX.1-schnell", + repo_flow="flux1-schnell.safetensors", + repo_ae="ae.safetensors", + ckpt_path=os.getenv("FLUX_SCHNELL"), + params=FluxParams( + in_channels=64, + vec_in_dim=768, + context_in_dim=4096, + hidden_size=3072, + mlp_ratio=4.0, + num_heads=24, + depth=19, + depth_single_blocks=38, + axes_dim=[16, 56, 56], + theta=10_000, + qkv_bias=True, + guidance_embed=False, + ), + ae_path=os.getenv("AE"), + ae_params=AutoEncoderParams( + resolution=256, + in_channels=3, + ch=128, + out_ch=3, + ch_mult=[1, 2, 4, 4], + num_res_blocks=2, + z_channels=16, + scale_factor=0.3611, + shift_factor=0.1159, + ), + ), +} + + +def load_flow_model(name: str, hf_download: bool = True): + # Get the safetensors file to load + ckpt_path = configs[name].ckpt_path + + # Download if needed + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_flow is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow) + + # Make the model + model = Flux(configs[name].params) + + # Load the checkpoint if needed + if ckpt_path is not None: + weights = mx.load(ckpt_path) + weights = model.sanitize(weights) + model.load_weights(list(weights.items())) + + return model + + +def load_ae(name: str, hf_download: bool = True): + # Get the safetensors file to load + ckpt_path = configs[name].ae_path + + # Download if needed + if ( + ckpt_path is None + and configs[name].repo_id is not None + and configs[name].repo_ae is not None + and hf_download + ): + ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_ae) + + # Make the autoencoder + ae = AutoEncoder(configs[name].ae_params) + + # Load the checkpoint if needed + if ckpt_path is not None: + weights = mx.load(ckpt_path) + weights = ae.sanitize(weights) + ae.load_weights(list(weights.items())) + + return ae + + +def load_clip(name: str): + # Load the config + config_path = hf_hub_download(configs[name].repo_id, "text_encoder/config.json") + with open(config_path) as f: + config = CLIPTextModelConfig.from_dict(json.load(f)) + + # Make the clip text encoder + clip = CLIPTextModel(config) + + # Load the weights + ckpt_path = 
hf_hub_download(configs[name].repo_id, "text_encoder/model.safetensors") + weights = mx.load(ckpt_path) + weights = clip.sanitize(weights) + clip.load_weights(list(weights.items())) + + return clip + + +def load_t5(name: str): + # Load the config + config_path = hf_hub_download(configs[name].repo_id, "text_encoder_2/config.json") + with open(config_path) as f: + config = T5Config.from_dict(json.load(f)) + + # Make the T5 model + t5 = T5Encoder(config) + + # Load the weights + model_index = hf_hub_download( + configs[name].repo_id, "text_encoder_2/model.safetensors.index.json" + ) + weight_files = set() + with open(model_index) as f: + for _, w in json.load(f)["weight_map"].items(): + weight_files.add(w) + weights = {} + for w in weight_files: + w = f"text_encoder_2/{w}" + w = hf_hub_download(configs[name].repo_id, w) + weights.update(mx.load(w)) + weights = t5.sanitize(weights) + t5.load_weights(list(weights.items())) + + return t5 + + +def load_clip_tokenizer(name: str): + vocab_file = hf_hub_download(configs[name].repo_id, "tokenizer/vocab.json") + with open(vocab_file, encoding="utf-8") as f: + vocab = json.load(f) + + merges_file = hf_hub_download(configs[name].repo_id, "tokenizer/merges.txt") + with open(merges_file, encoding="utf-8") as f: + bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1] + bpe_merges = [tuple(m.split()) for m in bpe_merges] + bpe_ranks = dict(map(reversed, enumerate(bpe_merges))) + + return CLIPTokenizer(bpe_ranks, vocab, max_length=77) + + +def load_t5_tokenizer(name: str, pad: bool = True): + model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model") + return T5Tokenizer(model_file, 256 if "schnell" in name else 512) diff --git a/flux/requirements.txt b/flux/requirements.txt new file mode 100644 index 00000000..792205c9 --- /dev/null +++ b/flux/requirements.txt @@ -0,0 +1,7 @@ +mlx>=0.18.1 +huggingface-hub +regex +numpy +tqdm +Pillow +sentencepiece diff --git a/flux/static/dog-r4-g8-1200-512x1024.png b/flux/static/dog-r4-g8-1200-512x1024.png new file mode 100644 index 00000000..7b1ca0e6 Binary files /dev/null and b/flux/static/dog-r4-g8-1200-512x1024.png differ diff --git a/flux/static/dog-r4-g8-1200.png b/flux/static/dog-r4-g8-1200.png new file mode 100644 index 00000000..90e47333 Binary files /dev/null and b/flux/static/dog-r4-g8-1200.png differ diff --git a/flux/static/dog6.png b/flux/static/dog6.png new file mode 100644 index 00000000..2bcf7b8c Binary files /dev/null and b/flux/static/dog6.png differ diff --git a/flux/static/generated-mlx.png b/flux/static/generated-mlx.png new file mode 100644 index 00000000..5c274ef4 Binary files /dev/null and b/flux/static/generated-mlx.png differ diff --git a/flux/txt2image.py b/flux/txt2image.py new file mode 100644 index 00000000..bf2f7294 --- /dev/null +++ b/flux/txt2image.py @@ -0,0 +1,150 @@ +# Copyright © 2024 Apple Inc. + +import argparse + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from PIL import Image +from tqdm import tqdm + +from flux import FluxPipeline + + +def to_latent_size(image_size): + h, w = image_size + h = ((h + 15) // 16) * 16 + w = ((w + 15) // 16) * 16 + + if (h, w) != image_size: + print( + "Warning: The image dimensions need to be divisible by 16px. " + f"Changing size to {h}x{w}." 
+        )
+
+    return (h // 8, w // 8)
+
+
+def quantization_predicate(name, m):
+    return hasattr(m, "to_quantized") and m.weight.shape[1] % 512 == 0
+
+
+def load_adapter(flux, adapter_file, fuse=False):
+    weights, lora_config = mx.load(adapter_file, return_metadata=True)
+    rank = int(lora_config["lora_rank"])
+    num_blocks = int(lora_config["lora_blocks"])
+    flux.linear_to_lora_layers(rank, num_blocks)
+    flux.flow.load_weights(list(weights.items()), strict=False)
+    if fuse:
+        flux.fuse_lora_layers()
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Generate images from a textual prompt using FLUX"
+    )
+    parser.add_argument("prompt")
+    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
+    parser.add_argument("--n-images", type=int, default=4)
+    parser.add_argument(
+        "--image-size", type=lambda x: tuple(map(int, x.split("x"))), default=(512, 512)
+    )
+    parser.add_argument("--steps", type=int)
+    parser.add_argument("--guidance", type=float, default=4.0)
+    parser.add_argument("--n-rows", type=int, default=1)
+    parser.add_argument("--decoding-batch-size", type=int, default=1)
+    parser.add_argument("--quantize", "-q", action="store_true")
+    parser.add_argument("--preload-models", action="store_true")
+    parser.add_argument("--output", default="out.png")
+    parser.add_argument("--save-raw", action="store_true")
+    parser.add_argument("--seed", type=int)
+    parser.add_argument("--verbose", "-v", action="store_true")
+    parser.add_argument("--adapter")
+    parser.add_argument("--fuse-adapter", action="store_true")
+    parser.add_argument("--no-t5-padding", dest="t5_padding", action="store_false")
+    args = parser.parse_args()
+
+    # Load the models
+    flux = FluxPipeline("flux-" + args.model, t5_padding=args.t5_padding)
+    args.steps = args.steps or (50 if args.model == "dev" else 2)
+
+    if args.adapter:
+        load_adapter(flux, args.adapter, fuse=args.fuse_adapter)
+
+    if args.quantize:
+        nn.quantize(flux.flow, class_predicate=quantization_predicate)
+        nn.quantize(flux.t5, class_predicate=quantization_predicate)
+        nn.quantize(flux.clip, class_predicate=quantization_predicate)
+
+    if args.preload_models:
+        flux.ensure_models_are_loaded()
+
+    # Make the generator
+    latent_size = to_latent_size(args.image_size)
+    latents = flux.generate_latents(
+        args.prompt,
+        n_images=args.n_images,
+        num_steps=args.steps,
+        latent_size=latent_size,
+        guidance=args.guidance,
+        seed=args.seed,
+    )
+
+    # First we get and eval the conditioning
+    conditioning = next(latents)
+    mx.eval(conditioning)
+    peak_mem_conditioning = mx.metal.get_peak_memory() / 1024**3
+    mx.metal.reset_peak_memory()
+
+    # The following is not necessary but it may help in memory constrained
+    # systems by reusing the memory kept by the text encoders.
+    del flux.t5
+    del flux.clip
+
+    # Actual denoising loop
+    for x_t in tqdm(latents, total=args.steps):
+        mx.eval(x_t)
+
+    # The following is not necessary but it may help in memory constrained
+    # systems by reusing the memory kept by the flow transformer.
+ del flux.flow + peak_mem_generation = mx.metal.get_peak_memory() / 1024**3 + mx.metal.reset_peak_memory() + + # Decode them into images + decoded = [] + for i in tqdm(range(0, args.n_images, args.decoding_batch_size)): + decoded.append(flux.decode(x_t[i : i + args.decoding_batch_size], latent_size)) + mx.eval(decoded[-1]) + peak_mem_decoding = mx.metal.get_peak_memory() / 1024**3 + peak_mem_overall = max( + peak_mem_conditioning, peak_mem_generation, peak_mem_decoding + ) + + if args.save_raw: + *name, suffix = args.output.split(".") + name = ".".join(name) + x = mx.concatenate(decoded, axis=0) + x = (x * 255).astype(mx.uint8) + for i in range(len(x)): + im = Image.fromarray(np.array(x[i])) + im.save(".".join([name, str(i), suffix])) + else: + # Arrange them on a grid + x = mx.concatenate(decoded, axis=0) + x = mx.pad(x, [(0, 0), (4, 4), (4, 4), (0, 0)]) + B, H, W, C = x.shape + x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4) + x = x.reshape(args.n_rows * H, B // args.n_rows * W, C) + x = (x * 255).astype(mx.uint8) + + # Save them to disc + im = Image.fromarray(np.array(x)) + im.save(args.output) + + # Report the peak memory used during generation + if args.verbose: + print(f"Peak memory used for the text: {peak_mem_conditioning:.3f}GB") + print(f"Peak memory used for the generation: {peak_mem_generation:.3f}GB") + print(f"Peak memory used for the decoding: {peak_mem_decoding:.3f}GB") + print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")