diff --git a/flux/flux/__init__.py b/flux/flux/__init__.py
index e69de29b..db5f4625 100644
--- a/flux/flux/__init__.py
+++ b/flux/flux/__init__.py
@@ -0,0 +1,145 @@
+import math
+import time
+from typing import Tuple
+
+import mlx.core as mx
+from tqdm import tqdm
+
+from .utils import (
+    load_ae,
+    load_clip,
+    load_clip_tokenizer,
+    load_flow_model,
+    load_t5,
+    load_t5_tokenizer,
+)
+
+
+class FluxPipeline:
+    def __init__(self, name: str):
+        self.name = name
+        self.ae = load_ae(name)
+        self.flow = load_flow_model(name)
+        self.clip = load_clip(name)
+        self.clip_tokenizer = load_clip_tokenizer(name)
+        self.t5 = load_t5(name)
+        self.t5_tokenizer = load_t5_tokenizer(name)
+        self.dtype = mx.bfloat16
+
+    def ensure_models_are_loaded(self):
+        mx.eval(
+            self.ae.parameters(),
+            self.flow.parameters(),
+            self.clip.parameters(),
+            self.t5.parameters(),
+        )
+
+    def _prior(self, n_images: int = 1, latent_size: Tuple[int, int] = (64, 64)):
+        return mx.random.normal(
+            shape=(n_images, *latent_size, 16),
+            dtype=self.dtype,
+        )
+
+    def _prepare(self, x, text):
+        b, h, w, c = x.shape
+
+        # Prepare the latent image input and its ids for positional encoding
+        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
+        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
+        x_ids = mx.concatenate(
+            [
+                mx.zeros((h // 2, w // 2, 1), dtype=mx.int32),
+                mx.broadcast_to(mx.arange(h // 2)[:, None, None], (h // 2, w // 2, 1)),
+                mx.broadcast_to(mx.arange(w // 2)[None, :, None], (h // 2, w // 2, 1)),
+            ],
+            axis=-1,
+        )
+        x_ids = mx.broadcast_to(x_ids.reshape(1, h * w // 4, 3), (b, h * w // 4, 3))
+
+        # Prepare the T5 text features
+        t5_tokens = mx.array([self.t5_tokenizer.tokenize(text)])
+        txt = self.t5(t5_tokens)
+        txt = mx.broadcast_to(txt, (b, *txt.shape[1:]))
+        txt_ids = mx.zeros((b, txt.shape[1], 3), dtype=mx.int32)
+
+        # Prepare the CLIP text features
+        clip_tokens = mx.array([self.clip_tokenizer.tokenize(text)])
+        vec = self.clip(clip_tokens).pooled_output
+        vec = mx.broadcast_to(vec, (b, *vec.shape[1:]))
+
+        return {
+            "img": x,
+            "img_ids": x_ids,
+            "txt": txt,
+            "txt_ids": txt_ids,
+            "vec": vec,
+        }
+
+    def _get_schedule(
+        self,
+        num_steps,
+        image_seq_len,
+        base_shift: float = 0.5,
+        max_shift: float = 1.5,
+        shift: bool = True,
+    ):
+        timesteps = mx.linspace(1, 0, num_steps + 1)
+
+        if shift:
+            x = image_seq_len
+            x1, x2 = 256, 4096
+            y1, y2 = base_shift, max_shift
+            mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
+            timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
+
+        return timesteps
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 35,
+        guidance: float = 4.0,
+        latent_size: Tuple[int, int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = int(time.time()) if seed is None else seed
+        mx.random.seed(seed)
+
+        # Create the latent variables
+        x_T = self._prior(n_images, latent_size)
+
+        # Get the initial inputs
+        inputs = self._prepare(x_T, text)
+
+        # Perform the denoising loop
+        mx.eval(inputs)
+        timesteps = self._get_schedule(
+            num_steps, inputs["img"].shape[1], shift="schnell" not in self.name
+        )
+        timesteps = timesteps.tolist()
+        guidance = mx.full((n_images,), guidance, dtype=self.dtype)
+        for t, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=num_steps):
+            t_arr = mx.full((n_images,), t, dtype=self.dtype)
+            pred = self.flow(
+                img=inputs["img"],
+                img_ids=inputs["img_ids"],
+                txt=inputs["txt"],
+                txt_ids=inputs["txt_ids"],
+                y=inputs["vec"],
+                timesteps=t_arr,
+                guidance=guidance,
+            )
+
+            inputs["img"] = inputs["img"] + (t_prev - t) * pred
+            mx.eval(inputs["img"])
+
+        img = inputs["img"]
+        h, w = latent_size
+        img = img.reshape(n_images, h // 2, w // 2, -1, 2, 2)
+        img = img.transpose(0, 1, 4, 2, 5, 3).reshape(n_images, h, w, -1)
+        img = self.ae.decode(img)
+        mx.eval(img)
+
+        return ((mx.clip(img, -1, 1) + 1) * 127.5).astype(mx.uint8)
diff --git a/flux/flux/t5.py b/flux/flux/t5.py
index 0fe39d0e..396f542d 100644
--- a/flux/flux/t5.py
+++ b/flux/flux/t5.py
@@ -85,7 +85,7 @@ class RelativePositionBias(nn.Module):
         buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
         buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
-        buckets = mx.where(is_small, rpos, buckets_large)
+        buckets = mx.where(is_small, abspos, buckets_large)
         if bidirectional:
             buckets = buckets + (rpos > 0) * num_buckets
         else:
@@ -140,7 +140,7 @@ class MultiHeadAttention(nn.Module):
         B, L, _ = queries.shape
         _, S, _ = keys.shape
         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
 
         if cache is not None:
@@ -149,7 +149,7 @@ class MultiHeadAttention(nn.Module):
             values = mx.concatenate([value_cache, values], axis=2)
 
         values_hat = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=1.0
+            queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
         )
         values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
 
@@ -216,81 +216,12 @@ class TransformerEncoder(nn.Module):
 
     def __call__(self, x: mx.array):
         pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
+        pos_bias = pos_bias.astype(x.dtype)
         for layer in self.layers:
             x = layer(x, mask=pos_bias)
         return self.ln(x)
 
 
-class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: T5Config):
-        super().__init__()
-        self.self_attention = MultiHeadAttention(config)
-        self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.dense = DenseActivation(config)
-
-    def __call__(
-        self,
-        x: mx.array,
-        memory: mx.array,
-        mask: mx.array,
-        memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
-        y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
-        x = x + y
-
-        y = self.ln2(x)
-        y, _ = self.cross_attention(y, memory, memory, memory_mask)
-        x = x + y
-
-        y = self.ln3(x)
-        y = self.dense(y)
-        x = x + y
-
-        return x, cache
-
-
-class TransformerDecoder(nn.Module):
-    def __init__(self, config: T5Config):
-        super().__init__()
-        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
-        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
-
-    def __call__(self, x, memory, mask, memory_mask, cache=None):
-        if cache is not None:
-            offset = cache[0][0].shape[3]
-        else:
-            offset = 0
-            cache = [None] * len(self.layers)
-
-        T = offset + x.shape[1]
-        pos_bias = self.relative_attention_bias(T, T, offset=offset)
-        if mask is not None:
-            mask += pos_bias
-        else:
-            mask = pos_bias
-
-        for e, layer in enumerate(self.layers):
-            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
-        x = self.ln(x)
-
-        return x, cache
-
-
-class OutputHead(nn.Module):
-    def __init__(self, config: T5Config):
-        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
-
-    def __call__(self, inputs):
-        return self.linear(inputs)
-
-
 class T5Encoder(nn.Module):
     def __init__(self, config: T5Config):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)
diff --git a/flux/flux/tokenizers.py b/flux/flux/tokenizers.py
index 523cef1a..074b5223 100644
--- a/flux/flux/tokenizers.py
+++ b/flux/flux/tokenizers.py
@@ -5,7 +5,8 @@ from sentencepiece import SentencePieceProcessor
 class CLIPTokenizer:
     """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
 
-    def __init__(self, bpe_ranks, vocab):
+    def __init__(self, bpe_ranks, vocab, max_length=77):
+        self.max_length = max_length
         self.bpe_ranks = bpe_ranks
         self.vocab = vocab
         self.pat = regex.compile(
@@ -96,6 +97,11 @@ class CLIPTokenizer:
         if append_eos:
             tokens.append(self.eos_token)
 
+        if len(tokens) > self.max_length:
+            tokens = tokens[: self.max_length]
+            if append_eos:
+                tokens[-1] = self.eos_token
+
         return tokens
 
 
diff --git a/flux/flux/utils.py b/flux/flux/utils.py
index 1671373e..7c8e9214 100644
--- a/flux/flux/utils.py
+++ b/flux/flux/utils.py
@@ -199,7 +199,7 @@ def load_clip_tokenizer(name: str):
     bpe_merges = [tuple(m.split()) for m in bpe_merges]
     bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
 
-    return CLIPTokenizer(bpe_ranks, vocab)
+    return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
 
 
 def load_t5_tokenizer(name: str):
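
Reviewer note: a minimal usage sketch of the new pipeline (assumptions: the loaders accept a checkpoint name such as "flux-schnell", and numpy/PIL are available for saving; note that generate_latents, despite its name, returns decoded uint8 images):

import numpy as np
from PIL import Image

from flux import FluxPipeline

pipeline = FluxPipeline("flux-schnell")  # assumed checkpoint name
images = pipeline.generate_latents(
    "a photo of an astronaut riding a horse",
    n_images=1,
    num_steps=4,            # schnell variants are distilled for few steps
    latent_size=(64, 64),   # the autoencoder decodes this 8x, to 512x512
    seed=0,
)
# images has shape (n_images, H, W, 3) in uint8
Image.fromarray(np.array(images[0])).save("out.png")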
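
The 2x2 patch packing in _prepare and the unpacking before ae.decode are meant to be exact inverses; a standalone round-trip check of just those reshape/transpose steps:

import mlx.core as mx

b, h, w, c = 2, 8, 8, 16
x = mx.random.normal((b, h, w, c))

# Pack: (b, h, w, c) -> (b, h*w/4, c*4), as in _prepare
packed = x.reshape(b, h // 2, 2, w // 2, 2, c)
packed = packed.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)

# Unpack: the inverse transform used at the end of generate_latents
out = packed.reshape(b, h // 2, w // 2, -1, 2, 2)
out = out.transpose(0, 1, 4, 2, 5, 3).reshape(b, h, w, -1)

assert mx.array_equal(x, out).item()  # the round trip is lossless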
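
On the schedule shift in _get_schedule: for the default 64x64 latent, the packed sequence length is 32 * 32 = 1024 tokens, so mu = (1024 - 256) * (1.5 - 0.5) / (4096 - 256) + 0.5 = 0.7, and each timestep t is remapped to exp(mu) / (exp(mu) + 1/t - 1), which keeps the endpoints at 1 and 0 while pushing intermediate steps toward higher noise. A quick standalone check of that arithmetic:

import math
import mlx.core as mx

num_steps, seq_len = 4, 1024  # 64x64 latent -> (64/2) * (64/2) packed tokens
timesteps = mx.linspace(1, 0, num_steps + 1)
mu = (seq_len - 256) * (1.5 - 0.5) / (4096 - 256) + 0.5  # = 0.7
shifted = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
print(shifted.tolist())  # starts at 1.0; 1/0 -> inf makes the last entry exactly 0.0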