Stable diffusion XL (#516)

This commit is contained in:
Angelos Katharopoulos
2024-03-08 10:24:19 -08:00
committed by GitHub
parent 8c2cf665ed
commit 3a9e6c3f70
11 changed files with 449 additions and 105 deletions

View File

@@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import time
from typing import Tuple
from typing import Optional, Tuple
import mlx.core as mx
@@ -13,7 +13,7 @@ from .model_io import (
load_tokenizer,
load_unet,
)
from .sampler import SimpleEulerSampler
from .sampler import SimpleEulerAncestralSampler, SimpleEulerSampler
class StableDiffusion:
@@ -22,10 +22,27 @@ class StableDiffusion:
self.diffusion_config = load_diffusion_config(model)
self.unet = load_unet(model, float16)
self.text_encoder = load_text_encoder(model, float16)
self.autoencoder = load_autoencoder(model, float16)
self.autoencoder = load_autoencoder(model, False)
self.sampler = SimpleEulerSampler(self.diffusion_config)
self.tokenizer = load_tokenizer(model)
def ensure_models_are_loaded(self):
    """Force evaluation of the parameters of every sub-model.

    ``mx.eval`` is called on the parameters of the UNet, the text encoder
    and the autoencoder, in that order.
    """
    for module in (self.unet, self.text_encoder, self.autoencoder):
        mx.eval(module.parameters())
def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
    """Tokenize the prompt(s) with ``tokenizer`` and stack them in one array.

    Row 0 always holds the tokens of ``text``. When ``negative_text`` is not
    ``None`` its tokens are placed in row 1. Shorter rows are right-padded
    with 0 so that every row has the length of the longest one.

    Returns:
        An ``mx.array`` built from the padded token lists.
    """
    prompts = [text] if negative_text is None else [text, negative_text]
    sequences = [tokenizer.tokenize(prompt) for prompt in prompts]

    # Right-pad with zeros up to the longest sequence
    longest = max(len(seq) for seq in sequences)
    padded = [seq + [0] * (longest - len(seq)) for seq in sequences]

    return mx.array(padded)
def _get_text_conditioning(
self,
text: str,
@@ -34,16 +51,12 @@ class StableDiffusion:
negative_text: str = "",
):
# Tokenize the text
tokens = [self.tokenizer.tokenize(text)]
if cfg_weight > 1:
tokens += [self.tokenizer.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
tokens = mx.array(tokens)
tokens = self._tokenize(
self.tokenizer, text, (negative_text if cfg_weight > 1 else None)
)
# Compute the features
conditioning = self.text_encoder(tokens)
conditioning = self.text_encoder(tokens).last_hidden_state
# Repeat the conditioning for each of the generated images
if n_images > 1:
@@ -51,10 +64,14 @@ class StableDiffusion:
return conditioning
def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
def _denoising_step(
self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5, text_time=None
):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
eps_pred = self.unet(
x_t_unet, t_unet, encoder_x=conditioning, text_time=text_time
)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
@@ -65,13 +82,21 @@ class StableDiffusion:
return x_t_prev
def _denoising_loop(
    self,
    x_T,
    T,
    conditioning,
    num_steps: int = 50,
    cfg_weight: float = 7.5,
    text_time=None,
):
    """Run the sampler schedule from ``T`` and yield the latents at each step.

    Args:
        x_T: The latents to start denoising from.
        T: The start time handed to ``self.sampler.timesteps``.
        conditioning: The text conditioning forwarded to every denoising step.
        num_steps: The number of steps requested from the sampler.
        cfg_weight: The classifier-free guidance weight forwarded to every
            denoising step.
        text_time: Optional extra conditioning forwarded unchanged to
            ``self._denoising_step`` (``None`` by default).

    Yields:
        The latents returned by ``self._denoising_step`` after each step.
    """
    x_t = x_T
    schedule = self.sampler.timesteps(num_steps, start_time=T, dtype=self.dtype)
    for t, t_prev in schedule:
        # text_time must be forwarded here, otherwise the extra conditioning
        # accepted by this method never reaches the UNet call
        x_t = self._denoising_step(
            x_t, t, t_prev, conditioning, cfg_weight, text_time
        )
        yield x_t
def generate_latents(
@@ -142,3 +167,100 @@ class StableDiffusion:
x = self.autoencoder.decode(x_t)
x = mx.clip(x / 2 + 0.5, 0, 1)
return x
class StableDiffusionXL(StableDiffusion):
    """Stable Diffusion pipeline variant that uses two text encoders.

    On top of what ``StableDiffusion.__init__`` loads, this class swaps the
    sampler for ``SimpleEulerAncestralSampler``, loads a second text encoder
    and tokenizer, and passes an additional ``text_time`` conditioning to the
    denoising loop.
    """

    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
        super().__init__(model, float16)

        self.sampler = SimpleEulerAncestralSampler(self.diffusion_config)

        # Expose the encoder/tokenizer loaded by the base class under numbered
        # names and drop the un-numbered attributes so they cannot be used by
        # accident.
        self.text_encoder_1 = self.text_encoder
        self.tokenizer_1 = self.tokenizer
        del self.tokenizer, self.text_encoder

        self.text_encoder_2 = load_text_encoder(
            model,
            float16,
            model_key="text_encoder_2",
        )
        self.tokenizer_2 = load_tokenizer(
            model,
            merges_key="tokenizer_2_merges",
            vocab_key="tokenizer_2_vocab",
        )

    def ensure_models_are_loaded(self):
        """Force evaluation of the parameters of all four sub-models."""
        mx.eval(self.unet.parameters())
        mx.eval(self.text_encoder_1.parameters())
        mx.eval(self.text_encoder_2.parameters())
        mx.eval(self.autoencoder.parameters())

    def _get_text_conditioning(
        self,
        text: str,
        n_images: int = 1,
        cfg_weight: float = 7.5,
        negative_text: str = "",
    ):
        """Encode the prompt with both text encoders.

        The negative prompt is only tokenized when ``cfg_weight > 1``.

        Returns:
            A tuple ``(conditioning, pooled_conditioning)`` where
            ``conditioning`` is the concatenation along the last axis of the
            second-to-last hidden state of both encoders and
            ``pooled_conditioning`` is the pooled output of the second
            encoder. Both are repeated ``n_images`` times along axis 0 when
            ``n_images > 1``.
        """
        negative = negative_text if cfg_weight > 1 else None
        tokens_1 = self._tokenize(self.tokenizer_1, text, negative)
        tokens_2 = self._tokenize(self.tokenizer_2, text, negative)

        conditioning_1 = self.text_encoder_1(tokens_1)
        conditioning_2 = self.text_encoder_2(tokens_2)
        conditioning = mx.concatenate(
            [conditioning_1.hidden_states[-2], conditioning_2.hidden_states[-2]],
            axis=-1,
        )
        pooled_conditioning = conditioning_2.pooled_output

        if n_images > 1:
            conditioning = mx.repeat(conditioning, n_images, axis=0)
            # The pooled conditioning is added to the per-sample time
            # embedding in the UNet, so its batch size has to follow the
            # latents as well. Without this, guidance (2 rows) combined with
            # n_images > 1 (2 * n_images latents) cannot broadcast.
            pooled_conditioning = mx.repeat(pooled_conditioning, n_images, axis=0)

        return conditioning, pooled_conditioning

    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 2,
        cfg_weight: float = 0.0,
        negative_text: str = "",
        latent_size: Tuple[int] = (64, 64),
        seed=None,
    ):
        """Yield the latents produced by each step of the denoising loop.

        Args:
            text: The prompt.
            n_images: How many latents to sample for the prompt.
            num_steps: The number of denoising steps.
            cfg_weight: The classifier-free guidance weight; the negative
                prompt is only used when it is larger than 1.
            negative_text: The negative prompt.
            latent_size: The spatial size of the sampled latents.
            seed: The PRNG seed. When ``None`` the current time in seconds is
                used; any other value, including 0, is used as given.
        """
        # Set the PRNG state. Compare against None explicitly so that an
        # explicit seed of 0 is honoured instead of being replaced.
        seed = int(time.time()) if seed is None else seed
        mx.random.seed(seed)

        # Get the text conditioning
        conditioning, pooled_conditioning = self._get_text_conditioning(
            text, n_images, cfg_weight, negative_text
        )
        # NOTE(review): the six numbers are fixed regardless of latent_size;
        # presumably they are the SDXL size/crop conditioning values for a
        # 512x512 image -- confirm before generating other sizes.
        text_time = (
            pooled_conditioning,
            mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
        )

        # Create the latent variables
        x_T = self.sampler.sample_prior(
            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
        )

        # Perform the denoising loop
        yield from self._denoising_loop(
            x_T,
            self.sampler.max_time,
            conditioning,
            num_steps,
            cfg_weight,
            text_time=text_time,
        )

View File

@@ -1,15 +1,33 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from .config import CLIPTextModelConfig
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPOutput:
    """Container for the outputs of the CLIP text model.

    Attributes:
        pooled_output: The last_hidden_state indexed at the EOS token and
            possibly projected if the model has a projection layer.
        last_hidden_state: The full sequence output of the transformer after
            the final layernorm.
        hidden_states: A list of hidden states corresponding to the outputs
            of the transformer layers.
    """

    pooled_output: Optional[mx.array] = None
    last_hidden_state: Optional[mx.array] = None
    hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int):
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
@@ -25,6 +43,8 @@ class CLIPEncoderLayer(nn.Module):
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
@@ -32,7 +52,7 @@ class CLIPEncoderLayer(nn.Module):
y = self.layer_norm2(x)
y = self.linear1(y)
y = nn.gelu_approx(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
@@ -48,23 +68,49 @@ class CLIPTextModel(nn.Module):
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads)
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
if config.projection_dim is not None:
self.text_projection = nn.Linear(
config.model_dims, config.projection_dim, bias=False
)
def _get_mask(self, N, dtype):
    """Build an (N, N) additive mask that penalizes later positions.

    Entries whose row index is smaller than their column index hold a large
    negative value (-6e4 for float16, -1e9 otherwise); all other entries
    are 0.
    """
    positions = mx.arange(N)
    is_later = positions[:, None] < positions[None]

    # float16 cannot represent -1e9, hence the smaller magnitude
    fill_value = -6e4 if dtype == mx.float16 else -1e9
    return is_later.astype(dtype) * fill_value
def __call__(self, x):
    """Encode a batch of token ids.

    Args:
        x: A 2D array of token ids (unpacked as ``B, N = x.shape``).

    Returns:
        A ``CLIPOutput`` holding the pooled output, the sequence output
        after the final layernorm and the list of per-layer outputs.
    """
    # Extract some shapes
    B, N = x.shape
    # Position of the largest token id in every row.
    # NOTE(review): this presumably relies on the EOS token having the
    # largest id in the vocabulary -- confirm against the tokenizer.
    eos_tokens = x.argmax(-1)

    # Compute the embeddings
    x = self.token_embedding(x)
    x = x + self.position_embedding.weight[:N]

    # Compute the features from the transformer, keeping every layer output
    mask = self._get_mask(N, x.dtype)
    hidden_states = []
    for layer in self.layers:
        x = layer(x, mask)
        hidden_states.append(x)

    # Apply the final layernorm
    x = self.final_layer_norm(x)
    last_hidden_state = x

    # Select the EOS token and project it if the model has a projection
    pooled_output = x[mx.arange(B), eos_tokens]
    if "text_projection" in self:
        pooled_output = self.text_projection(pooled_output)

    return CLIPOutput(
        pooled_output=pooled_output,
        last_hidden_state=last_hidden_state,
        hidden_states=hidden_states,
    )

View File

@@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
@@ -23,6 +23,8 @@ class CLIPTextModelConfig:
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dim: Optional[int] = None
hidden_act: str = "quick_gelu"
@dataclass
@@ -38,6 +40,21 @@ class UNetConfig:
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
)
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
projection_class_embeddings_input_dim: Optional[int] = None
@dataclass

View File

@@ -1,13 +1,12 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import json
from functools import partial
from typing import Optional
import mlx.core as mx
import numpy as np
from huggingface_hub import hf_hub_download
from mlx.utils import tree_unflatten
from safetensors import safe_open as safetensor_open
from .clip import CLIPTextModel
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
@@ -17,6 +16,22 @@ from .vae import Autoencoder
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = {
# See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
"stabilityai/sdxl-turbo": {
"unet_config": "unet/config.json",
"unet": "unet/diffusion_pytorch_model.safetensors",
"text_encoder_config": "text_encoder/config.json",
"text_encoder": "text_encoder/model.safetensors",
"text_encoder_2_config": "text_encoder_2/config.json",
"text_encoder_2": "text_encoder_2/model.safetensors",
"vae_config": "vae/config.json",
"vae": "vae/diffusion_pytorch_model.safetensors",
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
"tokenizer_2_vocab": "tokenizer_2/vocab.json",
"tokenizer_2_merges": "tokenizer_2/merges.txt",
},
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
"stabilityai/stable-diffusion-2-1-base": {
"unet_config": "unet/config.json",
@@ -28,14 +43,10 @@ _MODELS = {
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
}
},
}
def _from_numpy(x):
    """Convert ``x`` to an ``mx.array`` after making it C-contiguous."""
    contiguous = np.ascontiguousarray(x)
    return mx.array(contiguous)
def map_unet_weights(key, value):
# Map up/downsampling
if "downsamplers" in key:
@@ -67,9 +78,9 @@ def map_unet_weights(key, value):
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = np.split(value, 2)
v1, v2 = mx.split(value, 2)
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
return [(k1, v1), (k2, v2)]
if "conv_shortcut.weight" in key:
value = value.squeeze()
@@ -80,8 +91,9 @@ def map_unet_weights(key, value):
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
return [(key, _from_numpy(value))]
return [(key, value)]
def map_clip_text_encoder_weights(key, value):
@@ -109,7 +121,7 @@ def map_clip_text_encoder_weights(key, value):
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
return [(key, _from_numpy(value))]
return [(key, value)]
def map_vae_weights(key, value):
@@ -148,8 +160,9 @@ def map_vae_weights(key, value):
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
return [(key, _from_numpy(value))]
return [(key, value)]
def _flatten(params):
@@ -157,9 +170,9 @@ def _flatten(params):
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
    """Load a weight file, remap its entries and update ``model`` with them.

    Args:
        mapper: Callable taking ``(key, value)`` whose results are passed to
            ``_flatten`` before being unflattened into a parameter tree.
        model: The module whose ``update`` method receives the weights.
        weight_file: Path of the weight file, read once with ``mx.load``.
        float16: Cast every array to ``mx.float16`` when true, otherwise to
            ``mx.float32``.
    """
    dtype = mx.float16 if float16 else mx.float32

    # Read the file a single time, directly into MLX arrays
    loaded = mx.load(weight_file)
    mapped = _flatten([mapper(k, v.astype(dtype)) for k, v in loaded.items()])

    model.update(tree_unflatten(mapped))
@@ -186,6 +199,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
transformer_layers_per_block=config.get(
"transformer_layers_per_block", (1,) * 4
),
num_attention_heads=(
[config["attention_head_dim"]] * n_blocks
if isinstance(config["attention_head_dim"], int)
@@ -193,6 +209,13 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
),
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"][::-1],
addition_embed_type=config.get("addition_embed_type", None),
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
projection_class_embeddings_input_dim=config.get(
"projection_class_embeddings_input_dim", None
),
)
)
@@ -204,15 +227,24 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
return model
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
def load_text_encoder(
key: str = _DEFAULT_MODEL,
float16: bool = False,
model_key: str = "text_encoder",
config_key: Optional[str] = None,
):
"""Load the stable diffusion text encoder from Hugging Face Hub."""
_check_key(key, "load_text_encoder")
config_key = config_key or (model_key + "_config")
# Download the config and create the model
text_encoder_config = _MODELS[key]["text_encoder_config"]
text_encoder_config = _MODELS[key][config_key]
with open(hf_hub_download(key, text_encoder_config)) as f:
config = json.load(f)
with_projection = "WithProjection" in config["architectures"][0]
model = CLIPTextModel(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
@@ -220,11 +252,13 @@ def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dim=config["projection_dim"] if with_projection else None,
hidden_act=config.get("hidden_act", "quick_gelu"),
)
)
# Download the weights and map them into the model
text_encoder_weights = _MODELS[key]["text_encoder"]
text_encoder_weights = _MODELS[key][model_key]
weight_file = hf_hub_download(key, text_encoder_weights)
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
@@ -249,6 +283,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
block_out_channels=config["block_out_channels"],
layers_per_block=config["layers_per_block"],
norm_num_groups=config["norm_num_groups"],
scaling_factor=config.get("scaling_factor", 0.18215),
)
)
@@ -276,14 +311,18 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
)
def load_tokenizer(key: str = _DEFAULT_MODEL):
def load_tokenizer(
key: str = _DEFAULT_MODEL,
vocab_key: str = "tokenizer_vocab",
merges_key: str = "tokenizer_merges",
):
_check_key(key, "load_tokenizer")
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
vocab_file = hf_hub_download(key, _MODELS[key][vocab_key])
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
merges_file = hf_hub_download(key, _MODELS[key][merges_key])
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]

View File

@@ -83,3 +83,23 @@ class SimpleEulerSampler:
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev
class SimpleEulerAncestralSampler(SimpleEulerSampler):
    """Euler sampler whose step also injects freshly sampled noise."""

    def step(self, eps_pred, x_t, t, t_prev):
        """Move ``x_t`` from time ``t`` to ``t_prev`` given ``eps_pred``.

        The noise level at ``t_prev`` is split into ``sigma_down``, used for
        the deterministic update, and ``sigma_up``, which scales new normal
        noise added afterwards.
        """
        out_dtype = eps_pred.dtype
        sigma = self.sigmas(t).astype(out_dtype)
        sigma_prev = self.sigmas(t_prev).astype(out_dtype)

        sigma2 = sigma.square()
        sigma_prev2 = sigma_prev.square()

        # Split the target noise level into a noise part and a deterministic part
        sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
        sigma_down = (sigma_prev2 - sigma_up**2).sqrt()

        # Deterministic update towards sigma_down
        dt = sigma_down - sigma
        x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt

        # Add the new noise and rescale for the noise level at t_prev
        noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype)
        x_t_prev = x_t_prev + noise * sigma_up
        x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()

        return x_t_prev

View File

@@ -74,7 +74,7 @@ class TransformerBlock(nn.Module):
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
y = y_a * nn.gelu(y_b)
y = self.linear3(y)
x = x + y
@@ -106,10 +106,11 @@ class Transformer2D(nn.Module):
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
dtype = x.dtype
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x).reshape(B, -1, C)
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
@@ -150,15 +151,17 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
dtype = x.dtype
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x)
y = self.norm1(x.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y)
y = self.norm2(y.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv2(y)
@@ -292,6 +295,23 @@ class UNetModel(nn.Module):
config.block_out_channels[0] * 4,
)
if config.addition_embed_type == "text_time":
self.add_time_proj = nn.SinusoidalPositionalEncoding(
config.addition_time_embed_dim,
max_freq=1,
min_freq=math.exp(
-math.log(10000)
+ 2 * math.log(10000) / config.addition_time_embed_dim
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.add_embedding = TimestepEmbedding(
config.projection_class_embeddings_input_dim,
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
@@ -308,7 +328,7 @@ class UNetModel(nn.Module):
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention=(i < len(config.block_out_channels) - 1),
add_cross_attention="CrossAttn" in config.down_block_types[i],
)
for i, (in_channels, out_channels) in enumerate(
zip(block_channels, block_channels[1:])
@@ -357,7 +377,7 @@ class UNetModel(nn.Module):
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention=(i < len(config.block_out_channels) - 1),
add_cross_attention="CrossAttn" in config.up_block_types[i],
)
for i, (in_channels, out_channels, prev_out_channels) in reversed(
list(
@@ -380,11 +400,27 @@ class UNetModel(nn.Module):
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
def __call__(
self,
x,
timestep,
encoder_x,
attn_mask=None,
encoder_attn_mask=None,
text_time=None,
):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Add the extra text_time conditioning
if text_time is not None:
text_emb, time_ids = text_time
emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
emb = mx.concatenate([text_emb, emb], axis=-1)
emb = self.add_embedding(emb)
temb = temb + emb
# Preprocess the input
x = self.conv_in(x)
@@ -417,7 +453,8 @@ class UNetModel(nn.Module):
)
# Postprocess the output
x = self.conv_norm_out(x)
dtype = x.dtype
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x)
x = self.conv_out(x)