mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Bugfix in t5 rpos and initial generation example
This commit is contained in:
parent
88603f0330
commit
070c58ed92
@ -0,0 +1,145 @@
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
from tqdm import tqdm
|
||||
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_clip_tokenizer,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
    """Text-to-image sampling pipeline around a FLUX flow model.

    Bundles the autoencoder, the flow model and the CLIP and T5 text encoders
    (each with its tokenizer), all obtained from the ``load_*`` helpers for
    the given model ``name``.
    """

    def __init__(self, name: str):
        """Load every sub-model and tokenizer registered under ``name``."""
        self.name = name
        self.ae = load_ae(name)
        self.flow = load_flow_model(name)
        self.clip = load_clip(name)
        self.clip_tokenizer = load_clip_tokenizer(name)
        self.t5 = load_t5(name)
        self.t5_tokenizer = load_t5_tokenizer(name)
        # dtype used for the noise prior and for the per-step time/guidance
        # arrays created in ``generate_latents``.
        self.dtype = mx.bfloat16

    def ensure_models_are_loaded(self):
        """Call ``mx.eval`` on the parameters of all four models.

        NOTE(review): presumably this materializes lazily loaded weights up
        front so the cost is not paid inside the sampling loop -- confirm.
        """
        mx.eval(
            self.ae.parameters(),
            self.flow.parameters(),
            self.clip.parameters(),
            self.t5.parameters(),
        )

    def _prior(self, n_images: int = 1, latent_size: Tuple[int, int] = (64, 64)):
        """Sample normal noise of shape ``(n_images, *latent_size, 16)``."""
        return mx.random.normal(
            shape=(n_images, *latent_size, 16),
            dtype=self.dtype,
        )

    def _prepare(self, x, text):
        """Build the keyword inputs of the flow model for latents ``x``.

        Args:
            x: Latents of shape ``(b, h, w, c)`` with even ``h`` and ``w``.
            text: Prompt passed to both the T5 and the CLIP tokenizer.

        Returns:
            A dict with keys ``img`` (packed latents, ``(b, h*w/4, 4*c)``),
            ``img_ids`` (``(b, h*w/4, 3)`` position ids), ``txt`` (T5 output
            broadcast to batch ``b``), ``txt_ids`` (all-zero ids) and ``vec``
            (CLIP pooled output broadcast to batch ``b``).
        """
        b, h, w, c = x.shape

        # Prepare the latent image input and its ids for positional encoding.
        # Every 2x2 spatial patch is folded into the channel axis so that the
        # image becomes a sequence of h*w/4 tokens.
        x = x.reshape(b, h // 2, 2, w // 2, 2, c)
        x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
        # Per token: component 0 is always zero, component 1 is the patch row
        # and component 2 is the patch column.
        x_ids = mx.concatenate(
            [
                mx.zeros((h // 2, w // 2, 1), dtype=mx.int32),
                mx.broadcast_to(mx.arange(h // 2)[:, None, None], (h // 2, w // 2, 1)),
                mx.broadcast_to(mx.arange(w // 2)[None, :, None], (h // 2, w // 2, 1)),
            ],
            axis=-1,
        )
        x_ids = mx.broadcast_to(x_ids.reshape(1, h * w // 4, 3), (b, h * w // 4, 3))

        # Prepare the text features
        t5_tokens = mx.array([self.t5_tokenizer.tokenize(text)])
        txt = self.t5(t5_tokens)
        txt = mx.broadcast_to(txt, (b, *txt.shape[1:]))
        txt_ids = mx.zeros((b, txt.shape[1], 3), dtype=mx.int32)

        # Prepare the clip text features
        clip_tokens = mx.array([self.clip_tokenizer.tokenize(text)])
        vec = self.clip(clip_tokens).pooled_output
        vec = mx.broadcast_to(vec, (b, *vec.shape[1:]))

        return {
            "img": x,
            "img_ids": x_ids,
            "txt": txt,
            "txt_ids": txt_ids,
            "vec": vec,
        }

    def _get_shedule(
        self,
        num_steps,
        image_seq_len,
        base_shift: float = 0.5,
        max_shift: float = 1.5,
        shift: bool = True,
    ):
        """Return ``num_steps + 1`` time values running from 1 down to 0.

        When ``shift`` is true the uniform grid is warped with
        ``exp(mu) / (exp(mu) + (1 / t - 1))``, where ``mu`` is interpolated
        linearly in ``image_seq_len`` between ``base_shift`` at 256 tokens and
        ``max_shift`` at 4096 tokens.

        The method name keeps its original spelling so existing callers work.
        """
        timesteps = mx.linspace(1, 0, num_steps + 1)

        if shift:
            x = image_seq_len
            x1, x2 = 256, 4096
            y1, y2 = base_shift, max_shift
            mu = (x - x1) * (y2 - y1) / (x2 - x1) + y1
            timesteps = math.exp(mu) / (math.exp(mu) + (1 / timesteps - 1))

        return timesteps

    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 35,
        guidance: float = 4.0,
        latent_size: Tuple[int, int] = (64, 64),
        seed=None,
    ):
        """Sample ``n_images`` images for ``text`` and decode them.

        Args:
            text: The prompt.
            n_images: Number of images sampled in one batch.
            num_steps: Number of Euler steps of the denoising loop.
            guidance: Guidance value passed to the flow model at every step.
            latent_size: Spatial size ``(h, w)`` of the latents; both values
                must be even because 2x2 patches are packed into tokens.
            seed: PRNG seed; the current time in seconds is used when None.

        Returns:
            The output of ``self.ae.decode`` clipped to ``[-1, 1]`` and
            rescaled to ``[0, 255]`` as ``uint8``.
        """
        # Set the PRNG state
        seed = int(time.time()) if seed is None else seed
        mx.random.seed(seed)

        # Create the latent variables
        x_T = self._prior(n_images, latent_size)

        # Get the initial inputs
        inputs = self._prepare(x_T, text)
        mx.eval(inputs)

        # The schedule shift depends on the number of image tokens, i.e. the
        # length of the packed sequence (h * w / 4), not on the latent height
        # that ``x_T.shape[1]`` would give.
        timesteps = self._get_shedule(
            num_steps, inputs["img"].shape[1], shift="schnell" not in self.name
        )
        timesteps = timesteps.tolist()
        guidance = mx.full((n_images,), guidance, dtype=self.dtype)

        # Perform the denoising loop. ``zip`` has no length, so tqdm is told
        # the number of steps explicitly.
        step_pairs = zip(timesteps[:-1], timesteps[1:])
        for t, t_prev in tqdm(step_pairs, total=len(timesteps) - 1):
            t_arr = mx.full((n_images,), t, dtype=self.dtype)
            pred = self.flow(
                img=inputs["img"],
                img_ids=inputs["img_ids"],
                txt=inputs["txt"],
                txt_ids=inputs["txt_ids"],
                y=inputs["vec"],
                timesteps=t_arr,
                guidance=guidance,
            )

            # Euler step from t to t_prev along the predicted direction.
            inputs["img"] = inputs["img"] + (t_prev - t) * pred
            mx.eval(inputs["img"])

        # Undo the 2x2 patch packing done in ``_prepare`` before decoding.
        img = inputs["img"]
        h, w = latent_size
        img = img.reshape(n_images, h // 2, w // 2, -1, 2, 2)
        img = img.transpose(0, 1, 4, 2, 5, 3).reshape(n_images, h, w, -1)
        img = self.ae.decode(img)
        mx.eval(img)

        return ((mx.clip(img, -1, 1) + 1) * 127.5).astype(mx.uint8)
|
@ -85,7 +85,7 @@ class RelativePositionBias(nn.Module):
|
||||
buckets_large = (mx.log(abspos / max_exact) * scale).astype(mx.int16)
|
||||
buckets_large = mx.minimum(max_exact + buckets_large, num_buckets - 1)
|
||||
|
||||
buckets = mx.where(is_small, rpos, buckets_large)
|
||||
buckets = mx.where(is_small, abspos, buckets_large)
|
||||
if bidirectional:
|
||||
buckets = buckets + (rpos > 0) * num_buckets
|
||||
else:
|
||||
@ -140,7 +140,7 @@ class MultiHeadAttention(nn.Module):
|
||||
B, L, _ = queries.shape
|
||||
_, S, _ = keys.shape
|
||||
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 3, 1)
|
||||
keys = keys.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
values = values.reshape(B, S, num_heads, -1).transpose(0, 2, 1, 3)
|
||||
|
||||
if cache is not None:
|
||||
@ -149,7 +149,7 @@ class MultiHeadAttention(nn.Module):
|
||||
values = mx.concatenate([value_cache, values], axis=2)
|
||||
|
||||
values_hat = mx.fast.scaled_dot_product_attention(
|
||||
queries, keys, values, scale=1.0
|
||||
queries, keys, values, scale=1.0, mask=mask.astype(queries.dtype)
|
||||
)
|
||||
values_hat = values_hat.transpose(0, 2, 1, 3).reshape(B, L, -1)
|
||||
|
||||
@ -216,81 +216,12 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
def __call__(self, x: mx.array):
    """Run ``x`` through every encoder layer, then the final norm.

    The relative position bias is computed once for a square
    ``x.shape[1]`` by ``x.shape[1]`` layout, cast to the dtype of ``x`` and
    handed to each layer as its ``mask``.
    """
    seq_len = x.shape[1]
    bias = self.relative_attention_bias(seq_len, seq_len).astype(x.dtype)
    for block in self.layers:
        x = block(x, mask=bias)
    return self.ln(x)
|
||||
|
||||
|
||||
class TransformerDecoderLayer(nn.Module):
    """One decoder layer: self-attention, cross-attention, then dense block.

    Each of the three sub-blocks is applied to an RMS-normalized copy of the
    running state and added back to it as a residual.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        width = config.d_model
        eps = config.layer_norm_epsilon
        self.self_attention = MultiHeadAttention(config)
        self.cross_attention = MultiHeadAttention(config)
        self.ln1 = nn.RMSNorm(width, eps=eps)
        self.ln2 = nn.RMSNorm(width, eps=eps)
        self.ln3 = nn.RMSNorm(width, eps=eps)
        self.dense = DenseActivation(config)

    def __call__(
        self,
        x: mx.array,
        memory: mx.array,
        mask: mx.array,
        memory_mask: mx.array,
        cache: Optional[List[Tuple[mx.array, mx.array]]] = None,
    ):
        """Return the updated state and the self-attention cache.

        Args:
            x: Decoder state entering the layer.
            memory: Sequence attended to by the cross-attention; it is used
                as both keys and values.
            mask: Mask argument forwarded to the self-attention.
            memory_mask: Mask argument forwarded to the cross-attention.
            cache: Cache forwarded to the self-attention; the cache that the
                self-attention returns replaces it in the result.
        """
        # Self-attention over the normalized state (query = key = value).
        normed = self.ln1(x)
        attended, cache = self.self_attention(normed, normed, normed, mask, cache)
        x = x + attended

        # Cross-attention: the second return value is discarded.
        crossed, _ = self.cross_attention(self.ln2(x), memory, memory, memory_mask)
        x = x + crossed

        # Dense block.
        x = x + self.dense(self.ln3(x))

        return x, cache
|
||||
|
||||
|
||||
class TransformerDecoder(nn.Module):
    """Stack of ``TransformerDecoderLayer`` followed by a final RMS norm.

    The number of layers is ``config.num_decoder_layers`` when the config has
    that attribute and ``config.num_layers`` otherwise.
    """

    def __init__(self, config: T5Config):
        super().__init__()
        n_layers = getattr(config, "num_decoder_layers", config.num_layers)
        self.layers = [TransformerDecoderLayer(config) for _ in range(n_layers)]
        self.ln = nn.RMSNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.relative_attention_bias = RelativePositionBias(config, bidirectional=False)

    def __call__(self, x, memory, mask, memory_mask, cache=None):
        """Decode ``x`` against ``memory`` and return ``(x, cache)``.

        Without a cache, a fresh list with one ``None`` slot per layer is
        created. With a cache, the offset is read from axis 3 of the first
        array cached for the first layer.
        NOTE(review): presumably that axis holds the number of positions
        decoded so far -- verify against the cache layout of
        MultiHeadAttention.
        """
        if cache is None:
            offset = 0
            cache = [None] * len(self.layers)
        else:
            offset = cache[0][0].shape[3]

        total_len = offset + x.shape[1]
        pos_bias = self.relative_attention_bias(total_len, total_len, offset=offset)
        if mask is None:
            mask = pos_bias
        else:
            # Augmented assignment kept as is; it may update the array object
            # supplied by the caller.
            mask += pos_bias

        for idx, layer in enumerate(self.layers):
            x, cache[idx] = layer(x, memory, mask, memory_mask, cache=cache[idx])
        x = self.ln(x)

        return x, cache
|
||||
|
||||
|
||||
class OutputHead(nn.Module):
    """Bias-free linear projection from ``d_model`` to ``vocab_size``."""

    def __init__(self, config: T5Config):
        self.linear = nn.Linear(
            config.d_model,
            config.vocab_size,
            bias=False,
        )

    def __call__(self, inputs):
        """Apply the linear projection to ``inputs``."""
        projected = self.linear(inputs)
        return projected
|
||||
|
||||
|
||||
class T5Encoder(nn.Module):
|
||||
def __init__(self, config: T5Config):
|
||||
self.wte = nn.Embedding(config.vocab_size, config.d_model)
|
||||
|
@ -5,7 +5,8 @@ from sentencepiece import SentencePieceProcessor
|
||||
class CLIPTokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
|
||||
def __init__(self, bpe_ranks, vocab):
|
||||
def __init__(self, bpe_ranks, vocab, max_length=77):
|
||||
self.max_length = max_length
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
@ -96,6 +97,11 @@ class CLIPTokenizer:
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
if len(tokens) > self.max_length:
|
||||
tokens = tokens[: self.max_length]
|
||||
if append_eos:
|
||||
tokens[-1] = self.eos_token
|
||||
|
||||
return tokens
|
||||
|
||||
|
||||
|
@ -199,7 +199,7 @@ def load_clip_tokenizer(name: str):
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return CLIPTokenizer(bpe_ranks, vocab)
|
||||
return CLIPTokenizer(bpe_ranks, vocab, max_length=77)
|
||||
|
||||
|
||||
def load_t5_tokenizer(name: str):
|
||||
|
Loading…
Reference in New Issue
Block a user