Stable Diffusion XL (#516)
parent 8c2cf665ed
commit 3a9e6c3f70
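Note: for orientation, here is a minimal usage sketch of the `StableDiffusionXL` class this commit adds. The class name, the `"stabilityai/sdxl-turbo"` key, and the `generate_latents` defaults all come from the diff below; the prompt, the driver code, and the `decode` call (whose tail appears in the last hunk of the first file) are illustrative assumptions, not part of the commit. The file path line introducing each diff section below is reconstructed from the hunk contexts and imports.

```python
import mlx.core as mx
from stable_diffusion import StableDiffusionXL

sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=True)
sd.ensure_models_are_loaded()

# sdxl-turbo is distilled for few-step, guidance-free sampling, which is why
# generate_latents defaults to num_steps=2 and cfg_weight=0.0 in this diff.
x_t = None
for x_t in sd.generate_latents("a photo of an astronaut riding a horse"):
    mx.eval(x_t)

image = sd.decode(x_t)  # roughly (1, 512, 512, 3), values in [0, 1]
mx.eval(image)
```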
stable_diffusion/__init__.py
@@ -1,7 +1,7 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import time
-from typing import Tuple
+from typing import Optional, Tuple
 
 import mlx.core as mx
 
@@ -13,7 +13,7 @@ from .model_io import (
     load_tokenizer,
     load_unet,
 )
-from .sampler import SimpleEulerSampler
+from .sampler import SimpleEulerAncestralSampler, SimpleEulerSampler
 
 
 class StableDiffusion:
@@ -22,10 +22,27 @@ class StableDiffusion:
         self.diffusion_config = load_diffusion_config(model)
         self.unet = load_unet(model, float16)
         self.text_encoder = load_text_encoder(model, float16)
-        self.autoencoder = load_autoencoder(model, float16)
+        self.autoencoder = load_autoencoder(model, False)
         self.sampler = SimpleEulerSampler(self.diffusion_config)
         self.tokenizer = load_tokenizer(model)
 
     def ensure_models_are_loaded(self):
         mx.eval(self.unet.parameters())
         mx.eval(self.text_encoder.parameters())
         mx.eval(self.autoencoder.parameters())
 
+    def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
+        # Tokenize the text
+        tokens = [tokenizer.tokenize(text)]
+        if negative_text is not None:
+            tokens += [tokenizer.tokenize(negative_text)]
+        lengths = [len(t) for t in tokens]
+        N = max(lengths)
+        tokens = [t + [0] * (N - len(t)) for t in tokens]
+        tokens = mx.array(tokens)
+
+        return tokens
+
     def _get_text_conditioning(
         self,
         text: str,
@@ -34,16 +51,12 @@ class StableDiffusion:
         negative_text: str = "",
     ):
         # Tokenize the text
-        tokens = [self.tokenizer.tokenize(text)]
-        if cfg_weight > 1:
-            tokens += [self.tokenizer.tokenize(negative_text)]
-        lengths = [len(t) for t in tokens]
-        N = max(lengths)
-        tokens = [t + [0] * (N - len(t)) for t in tokens]
-        tokens = mx.array(tokens)
+        tokens = self._tokenize(
+            self.tokenizer, text, (negative_text if cfg_weight > 1 else None)
+        )
 
         # Compute the features
-        conditioning = self.text_encoder(tokens)
+        conditioning = self.text_encoder(tokens).last_hidden_state
 
         # Repeat the conditioning for each of the generated images
         if n_images > 1:
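Note: the new `_tokenize` helper pads the positive and the (optional) negative prompt with token id 0 to a shared length `N` so the two sequences stack into one batch. A standalone illustration of the padding (the token ids are illustrative; CLIP's BOS/EOS ids are 49406/49407):

```python
tokens = [[49406, 320, 1125, 49407], [49406, 49407]]
N = max(len(t) for t in tokens)
tokens = [t + [0] * (N - len(t)) for t in tokens]
print(tokens)  # [[49406, 320, 1125, 49407], [49406, 49407, 0, 0]]
```

Padding with 0 also keeps the EOS lookup added in clip.py below intact, since an `argmax` over token ids can never pick a pad position over the EOS id.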
@@ -51,10 +64,14 @@ class StableDiffusion:
 
         return conditioning
 
-    def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
+    def _denoising_step(
+        self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5, text_time=None
+    ):
         x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
         t_unet = mx.broadcast_to(t, [len(x_t_unet)])
-        eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+        eps_pred = self.unet(
+            x_t_unet, t_unet, encoder_x=conditioning, text_time=text_time
+        )
 
         if cfg_weight > 1:
             eps_text, eps_neg = eps_pred.split(2)
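Note: this hunk is truncated just before the guidance combination that consumes the split. For reference, the standard classifier-free guidance update looks like the sketch below; this is the textbook formula, not necessarily the exact elided line.

```python
# eps_text: noise prediction for the prompt; eps_neg: prediction for the
# negative/empty prompt; cfg_weight > 1 pushes the sample toward the prompt.
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
```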
@@ -65,13 +82,21 @@ class StableDiffusion:
         return x_t_prev
 
     def _denoising_loop(
-        self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5
+        self,
+        x_T,
+        T,
+        conditioning,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        text_time=None,
     ):
         x_t = x_T
         for t, t_prev in self.sampler.timesteps(
             num_steps, start_time=T, dtype=self.dtype
         ):
-            x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
+            x_t = self._denoising_step(
+                x_t, t, t_prev, conditioning, cfg_weight, text_time
+            )
             yield x_t
 
     def generate_latents(
@@ -142,3 +167,100 @@ class StableDiffusion:
         x = self.autoencoder.decode(x_t)
         x = mx.clip(x / 2 + 0.5, 0, 1)
         return x
+
+
+class StableDiffusionXL(StableDiffusion):
+    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
+        super().__init__(model, float16)
+
+        self.sampler = SimpleEulerAncestralSampler(self.diffusion_config)
+
+        self.text_encoder_1 = self.text_encoder
+        self.tokenizer_1 = self.tokenizer
+        del self.tokenizer, self.text_encoder
+
+        self.text_encoder_2 = load_text_encoder(
+            model,
+            float16,
+            model_key="text_encoder_2",
+        )
+        self.tokenizer_2 = load_tokenizer(
+            model,
+            merges_key="tokenizer_2_merges",
+            vocab_key="tokenizer_2_vocab",
+        )
+
+    def ensure_models_are_loaded(self):
+        mx.eval(self.unet.parameters())
+        mx.eval(self.text_encoder_1.parameters())
+        mx.eval(self.text_encoder_2.parameters())
+        mx.eval(self.autoencoder.parameters())
+
+    def _get_text_conditioning(
+        self,
+        text: str,
+        n_images: int = 1,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+    ):
+        tokens_1 = self._tokenize(
+            self.tokenizer_1,
+            text,
+            (negative_text if cfg_weight > 1 else None),
+        )
+        tokens_2 = self._tokenize(
+            self.tokenizer_2,
+            text,
+            (negative_text if cfg_weight > 1 else None),
+        )
+
+        conditioning_1 = self.text_encoder_1(tokens_1)
+        conditioning_2 = self.text_encoder_2(tokens_2)
+        conditioning = mx.concatenate(
+            [conditioning_1.hidden_states[-2], conditioning_2.hidden_states[-2]],
+            axis=-1,
+        )
+        pooled_conditioning = conditioning_2.pooled_output
+
+        if n_images > 1:
+            conditioning = mx.repeat(conditioning, n_images, axis=0)
+
+        return conditioning, pooled_conditioning
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 2,
+        cfg_weight: float = 0.0,
+        negative_text: str = "",
+        latent_size: Tuple[int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Get the text conditioning
+        conditioning, pooled_conditioning = self._get_text_conditioning(
+            text, n_images, cfg_weight, negative_text
+        )
+        text_time = (
+            pooled_conditioning,
+            mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
+        )
+
+        # Create the latent variables
+        x_T = self.sampler.sample_prior(
+            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
+        )
+
+        # Perform the denoising loop
+        yield from self._denoising_loop(
+            x_T,
+            self.sampler.max_time,
+            conditioning,
+            num_steps,
+            cfg_weight,
+            text_time=text_time,
+        )
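Note: the hard-coded `[512, 512, 0, 0, 512, 512]` follows SDXL's micro-conditioning convention of (original_h, original_w, crop_top, crop_left, target_h, target_w), i.e. an uncropped 512x512 generation. A dimension check of how the UNet consumes it, using sizes recalled from the published SDXL configs (256 and 1280 are assumptions, not values from this diff):

```python
import mlx.core as mx
import mlx.nn as nn

add_time_proj = nn.SinusoidalPositionalEncoding(256)  # addition_time_embed_dim
time_ids = mx.array([[512, 512, 0, 0, 512, 512.0]])   # (B, 6)
emb = add_time_proj(time_ids).flatten(1)              # (B, 6 * 256) = (B, 1536)
pooled = mx.zeros((1, 1280))                          # stand-in for pooled_conditioning
emb = mx.concatenate([pooled, emb], axis=-1)
print(emb.shape)  # (1, 2816) == projection_class_embeddings_input_dim
```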
stable_diffusion/clip.py
@@ -1,15 +1,33 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
+
+from dataclasses import dataclass
+from typing import List, Optional
 
 import mlx.core as mx
 import mlx.nn as nn
 
 from .config import CLIPTextModelConfig
 
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPOutput:
+    # The last_hidden_state indexed at the EOS token and possibly projected if
+    # the model has a projection layer
+    pooled_output: Optional[mx.array] = None
+
+    # The full sequence output of the transformer after the final layernorm
+    last_hidden_state: Optional[mx.array] = None
+
+    # A list of hidden states corresponding to the outputs of the transformer layers
+    hidden_states: Optional[List[mx.array]] = None
+
+
 class CLIPEncoderLayer(nn.Module):
     """The transformer encoder layer from CLIP."""
 
-    def __init__(self, model_dims: int, num_heads: int):
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
         super().__init__()
 
         self.layer_norm1 = nn.LayerNorm(model_dims)
@@ -25,6 +43,8 @@ class CLIPEncoderLayer(nn.Module):
         self.linear1 = nn.Linear(model_dims, 4 * model_dims)
         self.linear2 = nn.Linear(4 * model_dims, model_dims)
 
+        self.act = _ACTIVATIONS[activation]
+
     def __call__(self, x, attn_mask=None):
         y = self.layer_norm1(x)
         y = self.attention(y, y, y, attn_mask)
@@ -32,7 +52,7 @@ class CLIPEncoderLayer(nn.Module):
 
         y = self.layer_norm2(x)
         y = self.linear1(y)
-        y = nn.gelu_approx(y)
+        y = self.act(y)
         y = self.linear2(y)
         x = y + x
 
@@ -48,23 +68,49 @@ class CLIPTextModel(nn.Module):
         self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
         self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
         self.layers = [
-            CLIPEncoderLayer(config.model_dims, config.num_heads)
+            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
             for i in range(config.num_layers)
         ]
         self.final_layer_norm = nn.LayerNorm(config.model_dims)
 
+        if config.projection_dim is not None:
+            self.text_projection = nn.Linear(
+                config.model_dims, config.projection_dim, bias=False
+            )
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
     def __call__(self, x):
         # Extract some shapes
         B, N = x.shape
+        eos_tokens = x.argmax(-1)
 
         # Compute the embeddings
         x = self.token_embedding(x)
         x = x + self.position_embedding.weight[:N]
 
         # Compute the features from the transformer
-        mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
+        mask = self._get_mask(N, x.dtype)
+        hidden_states = []
         for l in self.layers:
             x = l(x, mask)
+            hidden_states.append(x)
 
         # Apply the final layernorm and return
-        return self.final_layer_norm(x)
+        x = self.final_layer_norm(x)
+        last_hidden_state = x
+
+        # Select the EOS token
+        pooled_output = x[mx.arange(len(x)), eos_tokens]
+        if "text_projection" in self:
+            pooled_output = self.text_projection(pooled_output)
+
+        return CLIPOutput(
+            pooled_output=pooled_output,
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+        )
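Note: two details in this file are easy to miss. First, `pooled_output` finds the EOS position with `x.argmax(-1)` over the raw token ids; this works because endoftext (49407) is the largest id in CLIP's vocabulary, so neither prompt tokens nor the pad id 0 can win the argmax. Second, `_get_mask` replaces the generic additive causal mask so the fill value stays representable in half precision (-6e4 instead of -1e9; float16 tops out around 65504). A runnable sketch of the pooling trick (shapes are illustrative):

```python
import mlx.core as mx

EOS = 49407  # CLIP's endoftext id, the largest id in the vocab
tokens = mx.array([[49406, 320, 1125, EOS, 0, 0]])  # padded prompt, (B, N)
eos_positions = tokens.argmax(-1)                   # -> [3]
hidden = mx.random.normal((1, 6, 8))                # stand-in transformer output
pooled = hidden[mx.arange(len(hidden)), eos_positions]
print(pooled.shape)  # (1, 8)
```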
stable_diffusion/config.py
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -23,6 +23,8 @@ class CLIPTextModelConfig:
     num_heads: int = 16
     max_length: int = 77
     vocab_size: int = 49408
+    projection_dim: Optional[int] = None
+    hidden_act: str = "quick_gelu"
 
 
 @dataclass
@@ -38,6 +40,21 @@ class UNetConfig:
     num_attention_heads: Tuple[int] = (5, 10, 20, 20)
     cross_attention_dim: Tuple[int] = (1024,) * 4
     norm_num_groups: int = 32
+    down_block_types: Tuple[str] = (
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "DownBlock2D",
+    )
+    up_block_types: Tuple[str] = (
+        "UpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+    )
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    projection_class_embeddings_input_dim: Optional[int] = None
 
 
 @dataclass
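Note: the three new optional `UNetConfig` fields are what let `load_unet` below describe SDXL's extra conditioning path; they default to `None` so existing SD 2.1 configs are unaffected. A sketch of the SDXL-style settings (the numeric values are recalled from the published sdxl-turbo `unet/config.json` and should be treated as assumptions):

```python
from stable_diffusion.config import UNetConfig

cfg = UNetConfig(
    addition_embed_type="text_time",             # enable the pooled-text + time_ids branch
    addition_time_embed_dim=256,                 # sinusoidal embedding width per time id
    projection_class_embeddings_input_dim=2816,  # 1280 (pooled text) + 6 * 256 (time_ids)
)
```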
stable_diffusion/model_io.py
@@ -1,13 +1,12 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import json
 from functools import partial
+from typing import Optional
 
 import mlx.core as mx
-import numpy as np
 from huggingface_hub import hf_hub_download
 from mlx.utils import tree_unflatten
-from safetensors import safe_open as safetensor_open
 
 from .clip import CLIPTextModel
 from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
@@ -17,6 +16,22 @@ from .vae import Autoencoder
 
 _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
 _MODELS = {
+    # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
+    "stabilityai/sdxl-turbo": {
+        "unet_config": "unet/config.json",
+        "unet": "unet/diffusion_pytorch_model.safetensors",
+        "text_encoder_config": "text_encoder/config.json",
+        "text_encoder": "text_encoder/model.safetensors",
+        "text_encoder_2_config": "text_encoder_2/config.json",
+        "text_encoder_2": "text_encoder_2/model.safetensors",
+        "vae_config": "vae/config.json",
+        "vae": "vae/diffusion_pytorch_model.safetensors",
+        "diffusion_config": "scheduler/scheduler_config.json",
+        "tokenizer_vocab": "tokenizer/vocab.json",
+        "tokenizer_merges": "tokenizer/merges.txt",
+        "tokenizer_2_vocab": "tokenizer_2/vocab.json",
+        "tokenizer_2_merges": "tokenizer_2/merges.txt",
+    },
+    # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
     "stabilityai/stable-diffusion-2-1-base": {
         "unet_config": "unet/config.json",
@@ -28,14 +43,10 @@ _MODELS = {
         "diffusion_config": "scheduler/scheduler_config.json",
         "tokenizer_vocab": "tokenizer/vocab.json",
         "tokenizer_merges": "tokenizer/merges.txt",
-    }
+    },
 }
 
 
-def _from_numpy(x):
-    return mx.array(np.ascontiguousarray(x))
-
-
 def map_unet_weights(key, value):
     # Map up/downsampling
     if "downsamplers" in key:
@@ -67,9 +78,9 @@ def map_unet_weights(key, value):
     if "ff.net.0" in key:
         k1 = key.replace("ff.net.0.proj", "linear1")
         k2 = key.replace("ff.net.0.proj", "linear2")
-        v1, v2 = np.split(value, 2)
+        v1, v2 = mx.split(value, 2)
 
-        return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
+        return [(k1, v1), (k2, v2)]
 
     if "conv_shortcut.weight" in key:
         value = value.squeeze()
@@ -80,8 +91,9 @@ def map_unet_weights(key, value):
 
     if len(value.shape) == 4:
         value = value.transpose(0, 2, 3, 1)
+        value = value.reshape(-1).reshape(value.shape)
 
-    return [(key, _from_numpy(value))]
+    return [(key, value)]
 
 
 def map_clip_text_encoder_weights(key, value):
@@ -109,7 +121,7 @@ def map_clip_text_encoder_weights(key, value):
     if "mlp.fc2" in key:
         key = key.replace("mlp.fc2", "linear2")
 
-    return [(key, _from_numpy(value))]
+    return [(key, value)]
 
 
 def map_vae_weights(key, value):
@@ -148,8 +160,9 @@ def map_vae_weights(key, value):
 
     if len(value.shape) == 4:
         value = value.transpose(0, 2, 3, 1)
+        value = value.reshape(-1).reshape(value.shape)
 
-    return [(key, _from_numpy(value))]
+    return [(key, value)]
 
 
 def _flatten(params):
@@ -157,9 +170,9 @@ def _flatten(params):
 
 
 def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
-    dtype = np.float16 if float16 else np.float32
-    with safetensor_open(weight_file, framework="numpy") as f:
-        weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
+    dtype = mx.float16 if float16 else mx.float32
+    weights = mx.load(weight_file)
+    weights = _flatten([mapper(k, v.astype(dtype)) for k, v in weights.items()])
     model.update(tree_unflatten(weights))
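Note: this loader rewrite is the reason numpy and safetensors dropped out of the imports: `mx.load` reads a `.safetensors` file directly into a dict of `mx.array`, so the weight mappers now receive MLX arrays and the `_from_numpy` wrapping disappears. The `value.reshape(-1).reshape(value.shape)` added after the conv-weight transposes appears to play the role `np.ascontiguousarray` used to, forcing a contiguous copy of the transposed weights. A minimal sketch of the new code path (the file name is hypothetical):

```python
import mlx.core as mx

weights = mx.load("model.safetensors")  # dict[str, mx.array]
weights = {k: v.astype(mx.float16) for k, v in weights.items()}
```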
@@ -186,6 +199,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             out_channels=config["out_channels"],
             block_out_channels=config["block_out_channels"],
             layers_per_block=[config["layers_per_block"]] * n_blocks,
+            transformer_layers_per_block=config.get(
+                "transformer_layers_per_block", (1,) * 4
+            ),
             num_attention_heads=(
                 [config["attention_head_dim"]] * n_blocks
                 if isinstance(config["attention_head_dim"], int)
@@ -193,6 +209,13 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             ),
             cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
             norm_num_groups=config["norm_num_groups"],
+            down_block_types=config["down_block_types"],
+            up_block_types=config["up_block_types"][::-1],
+            addition_embed_type=config.get("addition_embed_type", None),
+            addition_time_embed_dim=config.get("addition_time_embed_dim", None),
+            projection_class_embeddings_input_dim=config.get(
+                "projection_class_embeddings_input_dim", None
+            ),
         )
     )
 
@@ -204,15 +227,24 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     return model
 
 
-def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
+def load_text_encoder(
+    key: str = _DEFAULT_MODEL,
+    float16: bool = False,
+    model_key: str = "text_encoder",
+    config_key: Optional[str] = None,
+):
     """Load the stable diffusion text encoder from Hugging Face Hub."""
     _check_key(key, "load_text_encoder")
 
+    config_key = config_key or (model_key + "_config")
+
     # Download the config and create the model
-    text_encoder_config = _MODELS[key]["text_encoder_config"]
+    text_encoder_config = _MODELS[key][config_key]
     with open(hf_hub_download(key, text_encoder_config)) as f:
         config = json.load(f)
 
+    with_projection = "WithProjection" in config["architectures"][0]
+
     model = CLIPTextModel(
         CLIPTextModelConfig(
             num_layers=config["num_hidden_layers"],
@@ -220,11 +252,13 @@ def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
             num_heads=config["num_attention_heads"],
             max_length=config["max_position_embeddings"],
             vocab_size=config["vocab_size"],
+            projection_dim=config["projection_dim"] if with_projection else None,
+            hidden_act=config.get("hidden_act", "quick_gelu"),
         )
     )
 
     # Download the weights and map them into the model
-    text_encoder_weights = _MODELS[key]["text_encoder"]
+    text_encoder_weights = _MODELS[key][model_key]
     weight_file = hf_hub_download(key, text_encoder_weights)
     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
@@ -249,6 +283,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
             block_out_channels=config["block_out_channels"],
             layers_per_block=config["layers_per_block"],
             norm_num_groups=config["norm_num_groups"],
+            scaling_factor=config.get("scaling_factor", 0.18215),
         )
     )
 
@@ -276,14 +311,18 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
     )
 
 
-def load_tokenizer(key: str = _DEFAULT_MODEL):
+def load_tokenizer(
+    key: str = _DEFAULT_MODEL,
+    vocab_key: str = "tokenizer_vocab",
+    merges_key: str = "tokenizer_merges",
+):
     _check_key(key, "load_tokenizer")
 
-    vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
+    vocab_file = hf_hub_download(key, _MODELS[key][vocab_key])
     with open(vocab_file, encoding="utf-8") as f:
         vocab = json.load(f)
 
-    merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
+    merges_file = hf_hub_download(key, _MODELS[key][merges_key])
     with open(merges_file, encoding="utf-8") as f:
         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
         bpe_merges = [tuple(m.split()) for m in bpe_merges]
stable_diffusion/sampler.py
@@ -83,3 +83,23 @@ class SimpleEulerSampler:
         x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
 
         return x_t_prev
+
+
+class SimpleEulerAncestralSampler(SimpleEulerSampler):
+    def step(self, eps_pred, x_t, t, t_prev):
+        sigma = self.sigmas(t).astype(eps_pred.dtype)
+        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
+
+        sigma2 = sigma.square()
+        sigma_prev2 = sigma_prev.square()
+        sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
+        sigma_down = (sigma_prev2 - sigma_up**2).sqrt()
+
+        dt = sigma_down - sigma
+        x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt
+        noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype)
+        x_t_prev = x_t_prev + noise * sigma_up
+
+        x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()
+
+        return x_t_prev
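Note: the `sigma_up`/`sigma_down` split is the usual Euler-ancestral construction (as popularized by k-diffusion): take a deterministic Euler step down to `sigma_down`, then add fresh noise with standard deviation `sigma_up`, chosen so the variances recombine to exactly `sigma_prev**2`. The `(sigma**2 + 1)` factors convert between the variance-preserving latents the rest of the pipeline uses and the sigma-space the Euler step works in. A quick numeric check of the variance identity:

```python
import math

sigma, sigma_prev = 10.0, 6.0
sigma_up = math.sqrt(sigma_prev**2 * (sigma**2 - sigma_prev**2) / sigma**2)
sigma_down = math.sqrt(sigma_prev**2 - sigma_up**2)
assert abs(sigma_down**2 + sigma_up**2 - sigma_prev**2) < 1e-9
```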
stable_diffusion/unet.py
@@ -74,7 +74,7 @@ class TransformerBlock(nn.Module):
         y = self.norm3(x)
         y_a = self.linear1(y)
         y_b = self.linear2(y)
-        y = y_a * nn.gelu_approx(y_b)  # approximate gelu?
+        y = y_a * nn.gelu(y_b)
         y = self.linear3(y)
         x = x + y
 
@@ -106,10 +106,11 @@ class Transformer2D(nn.Module):
     def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
         # Save the input to add to the output
         input_x = x
+        dtype = x.dtype
 
         # Perform the input norm and projection
         B, H, W, C = x.shape
-        x = self.norm(x).reshape(B, -1, C)
+        x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
         x = self.proj_in(x)
 
         # Apply the transformer
@@ -150,15 +151,17 @@ class ResnetBlock2D(nn.Module):
             self.conv_shortcut = nn.Linear(in_channels, out_channels)
 
     def __call__(self, x, temb=None):
+        dtype = x.dtype
+
         if temb is not None:
             temb = self.time_emb_proj(nn.silu(temb))
 
-        y = self.norm1(x)
+        y = self.norm1(x.astype(mx.float32)).astype(dtype)
        y = nn.silu(y)
         y = self.conv1(y)
         if temb is not None:
             y = y + temb[:, None, None, :]
-        y = self.norm2(y)
+        y = self.norm2(y.astype(mx.float32)).astype(dtype)
         y = nn.silu(y)
         y = self.conv2(y)
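Note: the `astype` round-trips here (and in `Transformer2D` above and `conv_norm_out` below) apply a common half-precision pattern: compute the normalization statistics in float32 and cast back, since mean/variance estimates degrade or overflow in float16. A generic helper equivalent to what the diff inlines:

```python
import mlx.core as mx

def norm_in_fp32(norm, x):
    # Run the norm layer's statistics in float32, return in the input dtype.
    return norm(x.astype(mx.float32)).astype(x.dtype)
```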
@@ -292,6 +295,23 @@ class UNetModel(nn.Module):
             config.block_out_channels[0] * 4,
         )
 
+        if config.addition_embed_type == "text_time":
+            self.add_time_proj = nn.SinusoidalPositionalEncoding(
+                config.addition_time_embed_dim,
+                max_freq=1,
+                min_freq=math.exp(
+                    -math.log(10000)
+                    + 2 * math.log(10000) / config.addition_time_embed_dim
+                ),
+                scale=1.0,
+                cos_first=True,
+                full_turns=False,
+            )
+            self.add_embedding = TimestepEmbedding(
+                config.projection_class_embeddings_input_dim,
+                config.block_out_channels[0] * 4,
+            )
+
         # Make the downsampling blocks
         block_channels = [config.block_out_channels[0]] + list(
             config.block_out_channels
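Note: the `min_freq` expression above is an endpoint correction. The reference sinusoidal time embedding uses frequencies `10000 ** (-i / (d/2))` for `i = 0 .. d/2 - 1`, while `nn.SinusoidalPositionalEncoding` spaces its frequencies between `max_freq` and `min_freq` inclusively, so `min_freq` must be set to the smallest reference frequency. A check of the identity (with `d = 256` as an assumed SDXL value):

```python
import math

d = 256
min_freq = math.exp(-math.log(10000) + 2 * math.log(10000) / d)
assert abs(min_freq - 10000 ** (-(d / 2 - 1) / (d / 2))) < 1e-12
```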
@@ -308,7 +328,7 @@ class UNetModel(nn.Module):
                 resnet_groups=config.norm_num_groups,
                 add_downsample=(i < len(config.block_out_channels) - 1),
                 add_upsample=False,
-                add_cross_attention=(i < len(config.block_out_channels) - 1),
+                add_cross_attention="CrossAttn" in config.down_block_types[i],
             )
             for i, (in_channels, out_channels) in enumerate(
                 zip(block_channels, block_channels[1:])
@@ -357,7 +377,7 @@ class UNetModel(nn.Module):
                 resnet_groups=config.norm_num_groups,
                 add_downsample=False,
                 add_upsample=(i > 0),
-                add_cross_attention=(i < len(config.block_out_channels) - 1),
+                add_cross_attention="CrossAttn" in config.up_block_types[i],
             )
             for i, (in_channels, out_channels, prev_out_channels) in reversed(
                 list(
@@ -380,11 +400,27 @@ class UNetModel(nn.Module):
                 padding=(config.conv_out_kernel - 1) // 2,
             )
 
-    def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
+    def __call__(
+        self,
+        x,
+        timestep,
+        encoder_x,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        text_time=None,
+    ):
+        # Compute the time embeddings
         temb = self.timesteps(timestep).astype(x.dtype)
         temb = self.time_embedding(temb)
 
+        # Add the extra text_time conditioning
+        if text_time is not None:
+            text_emb, time_ids = text_time
+            emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
+            emb = mx.concatenate([text_emb, emb], axis=-1)
+            emb = self.add_embedding(emb)
+            temb = temb + emb
+
         # Preprocess the input
         x = self.conv_in(x)
 
@@ -417,7 +453,8 @@ class UNetModel(nn.Module):
         )
 
         # Postprocess the output
-        x = self.conv_norm_out(x)
+        dtype = x.dtype
+        x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
         x = nn.silu(x)
         x = self.conv_out(x)