Bugfix in t5 rpos and initial generation example

Angelos Katharopoulos 2024-09-28 01:09:59 -07:00
parent 88603f0330
commit 070c58ed92
4 changed files with 157 additions and 75 deletions

View File

@ -0,0 +1,145 @@
import math
import time
from typing import Tuple

import mlx.core as mx
from tqdm import tqdm

from .utils import (
    load_ae,
    load_clip,
    load_clip_tokenizer,
    load_flow_model,
    load_t5,
    load_t5_tokenizer,
)
class FluxPipeline:
    def __init__(self, name: str):
        self.name = name
        self.ae = load_ae(name)
        self.flow = load_flow_model(name)
        self.clip = load_clip(name)
        self.clip_tokenizer = load_clip_tokenizer(name)
        self.t5 = load_t5(name)
        self.t5_tokenizer = load_t5_tokenizer(name)
        self.dtype = mx.bfloat16

    def ensure_models_are_loaded(self):
        mx.eval(
            self.ae.parameters(),
            self.flow.parameters(),
            self.clip.parameters(),
            self.t5.parameters(),
        )

    def _prior(self, n_images: int = 1, latent_size: Tuple[int, int] = (64, 64)):
        return mx.random.normal(
            shape=(n_images, *latent_size, 16),
            dtype=self.dtype,
        )
    def _prepare(self, x, text):
        b, h, w, c = x.shape

        # Prepare the latent image input and its ids for positional encoding
        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
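        # Shape sketch (annotation, assuming the default 64x64 latent):
        # (b, 64, 64, 16) -> (b, 32, 2, 32, 2, 16) -> (b, 1024, 64),
        # i.e. every 2x2 patch of latent pixels becomes one 64-wide token.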
        x_ids = mx.concatenate(
            [
                mx.zeros((h // 2, w // 2, 1), dtype=mx.int32),
                mx.broadcast_to(mx.arange(h // 2)[:, None, None], (h // 2, w // 2, 1)),
                mx.broadcast_to(mx.arange(w // 2)[None, :, None], (h // 2, w // 2, 1)),
            ],
            axis=-1,
        )
        x_ids = mx.broadcast_to(x_ids.reshape(1, h * w // 4, 3), (b, h * w // 4, 3))

        # Prepare the T5 text features
        t5_tokens = mx.array([self.t5_tokenizer.tokenize(text)])
        txt = self.t5(t5_tokens)
        txt = mx.broadcast_to(txt, (b, *txt.shape[1:]))
        txt_ids = mx.zeros((b, txt.shape[1], 3), dtype=mx.int32)

        # Prepare the pooled CLIP text features
        clip_tokens = mx.array([self.clip_tokenizer.tokenize(text)])
        vec = self.clip(clip_tokens).pooled_output
        vec = mx.broadcast_to(vec, (b, *vec.shape[1:]))

        return {
            "img": x,
            "img_ids": x_ids,
            "txt": txt,
            "txt_ids": txt_ids,
            "vec": vec,
        }
    def _get_schedule(
        self,
        num_steps,
        image_seq_len,
        base_shift: float = 0.5,
        max_shift: float = 1.5,
        shift: bool = True,
    ):
        timesteps = mx.linspace(1, 0, num_steps + 1)

        if shift:
            # Interpolate the shift amount mu from the packed sequence length
            x = image_seq_len
            x1, x2 = 256, 4096
            y1, y2 = base_shift, max_shift
            mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
            timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))
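            # For example, with the default 64x64 latent the packed sequence has
            # 32 * 32 = 1024 tokens, so mu = (1024 - 256) * (1.5 - 0.5) / (4096 - 256)
            # + 0.5 = 0.7, and each t maps to e^0.7 / (e^0.7 + (1/t - 1)), which
            # bends the schedule toward spending more steps near t = 1 (high noise).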
        return timesteps
    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 35,
        guidance: float = 4.0,
        latent_size: Tuple[int, int] = (64, 64),
        seed=None,
    ):
        # Set the PRNG state
        seed = int(time.time()) if seed is None else seed
        mx.random.seed(seed)

        # Create the latent variables
        x_T = self._prior(n_images, latent_size)

        # Get the initial inputs
        inputs = self._prepare(x_T, text)
        mx.eval(inputs)

        # Perform the denoising loop (the schedule takes the packed sequence
        # length, not the latent height)
        timesteps = self._get_schedule(
            num_steps, inputs["img"].shape[1], shift="schnell" not in self.name
        )
        timesteps = timesteps.tolist()
        guidance = mx.full((n_images,), guidance, dtype=self.dtype)
        for t, t_prev in tqdm(zip(timesteps[:-1], timesteps[1:]), total=num_steps):
            t_arr = mx.full((n_images,), t, dtype=self.dtype)
            pred = self.flow(
                img=inputs["img"],
                img_ids=inputs["img_ids"],
                txt=inputs["txt"],
                txt_ids=inputs["txt_ids"],
                y=inputs["vec"],
                timesteps=t_arr,
                guidance=guidance,
            )
            inputs["img"] = inputs["img"] + (t_prev - t) * pred
            mx.eval(inputs["img"])

        # Unpack the 2x2 patches and decode the latents to pixels
        img = inputs["img"]
        h, w = latent_size
        img = img.reshape(n_images, h // 2, w // 2, -1, 2, 2)
        img = img.transpose(0, 1, 4, 2, 5, 3).reshape(n_images, h, w, -1)
        img = self.ae.decode(img)
        mx.eval(img)

        return ((mx.clip(img, -1, 1) + 1) * 127.5).astype(mx.uint8)
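
For reference, a minimal way to drive the new pipeline; the import path, the model name, and the use of numpy/Pillow for saving are assumptions, not part of this diff:

import numpy as np
from PIL import Image

from flux import FluxPipeline  # assumed import path

pipeline = FluxPipeline("flux-schnell")  # assumed model name; "schnell" skips the shift
images = pipeline.generate_latents(
    "a photo of an astronaut riding a horse on mars",
    n_images=1,
    num_steps=4,
    seed=42,
)
Image.fromarray(np.array(images[0])).save("out.png")  # (512, 512, 3) uint8 image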

View File

@ -85,7 +85,7 @@ class RelativePositionBias(nn.Module):
         buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
         buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
-        buckets = mx.where(is_small, rpos, buckets_large)
+        buckets = mx.where(is_small, abspos, buckets_large)
         if bidirectional:
             buckets = buckets + (rpos > 0) * num_buckets
         else:
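
This is the rpos bugfix from the commit title: rpos is the signed relative position, so bucketing on it directly produced negative bucket indices for positions to the left; the bucket must come from the absolute distance, with the direction handled by the separate offset. A rough sketch of the intended behavior (bucket sizes are illustrative, not the real config):

import mlx.core as mx

rpos = mx.array([-3, -1, 0, 1, 3])  # signed relative positions
num_buckets, max_exact = 16, 8      # illustrative values
abspos = mx.abs(rpos)
is_small = abspos < max_exact
# buggy: mx.where(is_small, rpos, ...) keeps the sign -> negative indices
buckets = mx.where(is_small, abspos, num_buckets - 1)
buckets = buckets + (rpos > 0) * num_buckets  # direction-specific buckets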
@ -140,7 +140,7 @@ class MultiHeadAttention(nn.Module):
         B, L, _ = queries.shape
         _, S, _ = keys.shape

         queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
-        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
+        keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
         values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)

         if cache is not None:
@ -149,7 +149,7 @@ class MultiHeadAttention(nn.Module):
             values = mx.concatenate([value_cache, values], axis=2)

         values_hat = mx.fast.scaled_dot_product_attention(
-            queries, keys, values, scale=1.0
+            queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
         )
         values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
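
Both hunks above adapt the attention to mx.fast.scaled_dot_product_attention, which expects queries, keys, and values all in (B, num_heads, seq_len, head_dim) layout and applies an additive mask internally: the old (0, 2, 3, 1) transpose had pre-transposed the keys for a manual matmul, and the relative position bias was never being passed in. A small layout sketch (sizes are illustrative):

import mlx.core as mx

B, H, L, D = 1, 8, 16, 64
q = mx.random.normal((B, H, L, D))
k = mx.random.normal((B, H, L, D))  # same layout as q and v, not (B, H, D, L)
v = mx.random.normal((B, H, L, D))
bias = mx.zeros((1, H, L, L))       # e.g. a relative position bias
out = mx.fast.scaled_dot_product_attention(q, k, v, scale=1.0, mask=bias)
print(out.shape)  # (1, 8, 16, 64)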
@ -216,81 +216,12 @@ class TransformerEncoder(nn.Module):
     def __call__(self, x: mx.array):
         pos_bias = self.relative_attention_bias(x.shape[1], x.shape[1])
         pos_bias = pos_bias.astype(x.dtype)
         for layer in self.layers:
             x = layer(x, mask=pos_bias)
         return self.ln(x)

-class TransformerDecoderLayer(nn.Module):
-    def __init__(self, config: T5Config):
-        super().__init__()
-        self.self_attention = MultiHeadAttention(config)
-        self.cross_attention = MultiHeadAttention(config)
-        self.ln1 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln2 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.ln3 = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.dense = DenseActivation(config)
-
-    def __call__(
-        self,
-        x: mx.array,
-        memory: mx.array,
-        mask: mx.array,
-        memory_mask: mx.array,
-        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
-    ):
-        y = self.ln1(x)
-        y, cache = self.self_attention(y, y, y, mask, cache)
-        x = x + y
-
-        y = self.ln2(x)
-        y, _ = self.cross_attention(y, memory, memory, memory_mask)
-        x = x + y
-
-        y = self.ln3(x)
-        y = self.dense(y)
-        x = x + y
-
-        return x, cache
-
-class TransformerDecoder(nn.Module):
-    def __init__(self, config: T5Config):
-        super().__init__()
-        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
-        self.layers = [TransformerDecoderLayer(config) for i in range(n_layers)]
-        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
-        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)
-
-    def __call__(self, x, memory, mask, memory_mask, cache=None):
-        if cache is not None:
-            offset = cache[0][0].shape[3]
-        else:
-            offset = 0
-            cache = [None] * len(self.layers)
-
-        T = offset + x.shape[1]
-        pos_bias = self.relative_attention_bias(T, T, offset=offset)
-        if mask is not None:
-            mask += pos_bias
-        else:
-            mask = pos_bias
-
-        for e, layer in enumerate(self.layers):
-            x, cache[e] = layer(x, memory, mask, memory_mask, cache=cache[e])
-        x = self.ln(x)
-
-        return x, cache
-
-class OutputHead(nn.Module):
-    def __init__(self, config: T5Config):
-        self.linear = nn.Linear(config.d_model, config.vocab_size, bias=False)
-
-    def __call__(self, inputs):
-        return self.linear(inputs)
-
 class T5Encoder(nn.Module):
     def __init__(self, config: T5Config):
         self.wte = nn.Embedding(config.vocab_size, config.d_model)

View File

@ -5,7 +5,8 @@ from sentencepiece import SentencePieceProcessor
 class CLIPTokenizer:
     """A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""

-    def __init__(self, bpe_ranks, vocab):
+    def __init__(self, bpe_ranks, vocab, max_length=77):
+        self.max_length = max_length
         self.bpe_ranks = bpe_ranks
         self.vocab = vocab
         self.pat = regex.compile(
@ -96,6 +97,11 @@ class CLIPTokenizer:
         if append_eos:
             tokens.append(self.eos_token)

+        if len(tokens) > self.max_length:
+            tokens = tokens[: self.max_length]
+            if append_eos:
+                tokens[-1] = self.eos_token
+
         return tokens
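
The truncation keeps prompts within CLIP's fixed context window while still terminating with EOS. A quick sketch of the behavior (token values made up for illustration):

max_length, eos_token = 5, 99          # illustrative; the real max_length is 77
tokens = [1, 10, 20, 30, 40, 50, 99]   # EOS already appended
if len(tokens) > max_length:
    tokens = tokens[:max_length]
    tokens[-1] = eos_token             # re-terminate after truncation
print(tokens)  # [1, 10, 20, 30, 99]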

View File

@ -199,7 +199,7 @@ def load_clip_tokenizer(name: str):
     bpe_merges = [tuple(m.split()) for m in bpe_merges]
     bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))

-    return CLIPTokenizer(bpe_ranks, vocab)
+    return CLIPTokenizer(bpe_ranks, vocab, max_length=77)

 def load_t5_tokenizer(name: str):