mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 20:04:38 +08:00
Add the Llama and Stable Diffusion examples
This commit is contained in:
95
stable_diffusion/README.md
Normal file
95
stable_diffusion/README.md
Normal file
@@ -0,0 +1,95 @@
|
||||
Stable Diffusion
|
||||
================
|
||||
|
||||
Stable Diffusion in MLX. The implementation was ported from Hugginface's
|
||||
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
|
||||
and using the weights available on the Huggingface Hub by Stability AI at
|
||||
[stabilitiai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
|
||||
|
||||

|
||||
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*
|
||||
|
||||
Installation
|
||||
------------
|
||||
|
||||
The dependencies are minimal, namely:
|
||||
|
||||
- `safetensors` and `huggingface-hub` to load the checkpoints.
|
||||
- `regex` for the tokenization
|
||||
- `numpy` because safetensors needs to return some form of array
|
||||
- `tqdm` and `PIL` for the `txt2image.py` script
|
||||
|
||||
You can install all of the above with the `requirements.txt` as follows:
|
||||
|
||||
pip install -r requirements.txt
|
||||
|
||||
Usage
|
||||
------
|
||||
|
||||
Although each component in this repository can be used by itsself, the fastest
|
||||
way to get started is by using the `StableDiffusion` class from the `diffusion`
|
||||
module.
|
||||
|
||||
```python
|
||||
from stable_diffusion import StableDiffusion
|
||||
|
||||
# This will download all the weights from HF hub and load the models in
|
||||
# memory
|
||||
sd = StableDiffusion()
|
||||
|
||||
# This creates a python generator that returns the latent produced by the
|
||||
# reverse diffusion process.
|
||||
#
|
||||
# Because MLX is lazily evaluated iterating over this generator doesn't
|
||||
# actually perform the computation until mx.eval() is called.
|
||||
latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")
|
||||
|
||||
# Here we are evaluating each diffusion step but we could also evaluate
|
||||
# once at the end.
|
||||
for x_t in latent_generator:
|
||||
mx.simplify(x_t) # remove possible redundant computation eg reuse
|
||||
# scalars etc
|
||||
mx.eval(x_t)
|
||||
|
||||
# Now x_t is the last latent from the reverse process aka x_0. We can
|
||||
# decode it into an image using the stable diffusion VAE.
|
||||
im = sd.decode(x_t)
|
||||
```
|
||||
|
||||
The above is almost line for line the implementation of the `txt2image.py`
|
||||
script in the root of the repository. You can use the script as follows:
|
||||
|
||||
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
|
||||
|
||||
Performance
|
||||
-----------
|
||||
|
||||
The following table compares the performance of the UNet in stable diffusion.
|
||||
We report throughput in images per second for the provided `txt2image.py`
|
||||
script and the `diffusers` library using the MPS PyTorch backend.
|
||||
|
||||
At the time of writing this comparison convolutions are still some of the least
|
||||
optimized operations in MLX. Despite that, MLX still achieves **~40% higher
|
||||
throughput** than PyTorch with a batch size of 16 and ~15% higher when
|
||||
comparing the optimal batch sizes.
|
||||
|
||||
Notably, PyTorch achieves almost ~50% higher throughput for the batch size of 1
|
||||
which is unfortunate as that means that a single image can be computed faster.
|
||||
However, when starting with the models not loaded in memory and PyTorch's MPS
|
||||
graph kernels not cached, the compilation time more than accounts for this
|
||||
speed difference.
|
||||
|
||||
| Batch size | PyTorch | MLX |
|
||||
| ---------- | ----------- | ----------- |
|
||||
| 1 | 6.25 im/s | 4.17 im/s |
|
||||
| 2 | 7.14 im/s | 5.88 im/s |
|
||||
| 4 |**7.69 im/s**| 7.14 im/s |
|
||||
| 6 | 7.22 im/s | 8.00 im/s |
|
||||
| 8 | 6.89 im/s | 8.42 im/s |
|
||||
| 12 | 6.62 im/s | 8.51 im/s |
|
||||
| 16 | 6.32 im/s |**8.79 im/s**|
|
||||
|
||||
The above experiments were made on an M2 Ultra with PyTorch version 2.1,
|
||||
diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
|
||||
used classifier free guidance which means that the above batch sizes result
|
||||
double the images processed by the UNet.
|
BIN
stable_diffusion/generated-mlx.png
Normal file
BIN
stable_diffusion/generated-mlx.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 185 KiB |
6
stable_diffusion/requirements.txt
Normal file
6
stable_diffusion/requirements.txt
Normal file
@@ -0,0 +1,6 @@
|
||||
safetensors
|
||||
huggingface-hub
|
||||
regex
|
||||
numpy
|
||||
tqdm
|
||||
Pillow
|
96
stable_diffusion/stable_diffusion/__init__.py
Normal file
96
stable_diffusion/stable_diffusion/__init__.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from .model_io import (
|
||||
load_unet,
|
||||
load_text_encoder,
|
||||
load_autoencoder,
|
||||
load_diffusion_config,
|
||||
load_tokenizer,
|
||||
_DEFAULT_MODEL,
|
||||
)
|
||||
from .sampler import SimpleEulerSampler
|
||||
|
||||
|
||||
def _repeat(x, n, axis):
|
||||
# Make the expanded shape
|
||||
s = x.shape
|
||||
s.insert(axis + 1, n)
|
||||
|
||||
# Expand
|
||||
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
|
||||
|
||||
# Make the flattened shape
|
||||
s.pop(axis + 1)
|
||||
s[axis] *= n
|
||||
|
||||
return x.reshape(s)
|
||||
|
||||
|
||||
class StableDiffusion:
|
||||
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
self.dtype = mx.float16 if float16 else mx.float32
|
||||
self.diffusion_config = load_diffusion_config(model)
|
||||
self.unet = load_unet(model, float16)
|
||||
self.text_encoder = load_text_encoder(model, float16)
|
||||
self.autoencoder = load_autoencoder(model, float16)
|
||||
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||
self.tokenizer = load_tokenizer(model)
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 50,
|
||||
cfg_weight: float = 7.5,
|
||||
negative_text: str = "",
|
||||
latent_size: Tuple[int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
seed = seed or int(time.time())
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Tokenize the text
|
||||
tokens = [self.tokenizer.tokenize(text)]
|
||||
if cfg_weight > 1:
|
||||
tokens += [self.tokenizer.tokenize(negative_text)]
|
||||
lengths = [len(t) for t in tokens]
|
||||
N = max(lengths)
|
||||
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||
tokens = mx.array(tokens)
|
||||
|
||||
# Compute the features
|
||||
conditioning = self.text_encoder(tokens)
|
||||
|
||||
# Repeat the conditioning for each of the generated images
|
||||
if n_images > 1:
|
||||
conditioning = _repeat(conditioning, n_images, axis=0)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior(
|
||||
(n_images, *latent_size, self.autoencoder.latent_channels),
|
||||
dtype=self.dtype
|
||||
)
|
||||
|
||||
# Perform the denoising loop
|
||||
x_t = x_T
|
||||
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
|
||||
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
|
||||
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
|
||||
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
|
||||
|
||||
if cfg_weight > 1:
|
||||
eps_text, eps_neg = eps_pred.split(2)
|
||||
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
|
||||
|
||||
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
|
||||
x_t = x_t_prev
|
||||
yield x_t
|
||||
|
||||
def decode(self, x_t):
|
||||
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
|
||||
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
|
||||
return x
|
68
stable_diffusion/stable_diffusion/clip.py
Normal file
68
stable_diffusion/stable_diffusion/clip.py
Normal file
@@ -0,0 +1,68 @@
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .config import CLIPTextModelConfig
|
||||
|
||||
|
||||
class CLIPEncoderLayer(nn.Module):
|
||||
"""The transformer encoder layer from CLIP."""
|
||||
|
||||
def __init__(self, model_dims: int, num_heads: int):
|
||||
super().__init__()
|
||||
|
||||
self.layer_norm1 = nn.LayerNorm(model_dims)
|
||||
self.layer_norm2 = nn.LayerNorm(model_dims)
|
||||
|
||||
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
# Add biases to the attention projections to match CLIP
|
||||
self.attention.query_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.key_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.value_proj.bias = mx.zeros(model_dims)
|
||||
self.attention.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
|
||||
self.linear2 = nn.Linear(4 * model_dims, model_dims)
|
||||
|
||||
def __call__(self, x, attn_mask=None):
|
||||
y = self.layer_norm1(x)
|
||||
y = self.attention(y, y, y, attn_mask)
|
||||
x = y + x
|
||||
|
||||
y = self.layer_norm2(x)
|
||||
y = self.linear1(y)
|
||||
y = nn.gelu_approx(y)
|
||||
y = self.linear2(y)
|
||||
x = y + x
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class CLIPTextModel(nn.Module):
|
||||
"""Implements the text encoder transformer from CLIP."""
|
||||
|
||||
def __init__(self, config: CLIPTextModelConfig):
|
||||
super().__init__()
|
||||
|
||||
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
|
||||
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
|
||||
self.layers = [
|
||||
CLIPEncoderLayer(config.model_dims, config.num_heads)
|
||||
for i in range(config.num_layers)
|
||||
]
|
||||
self.final_layer_norm = nn.LayerNorm(config.model_dims)
|
||||
|
||||
def __call__(self, x):
|
||||
# Extract some shapes
|
||||
B, N = x.shape
|
||||
|
||||
# Compute the embeddings
|
||||
x = self.token_embedding(x)
|
||||
x = x + self.position_embedding.weight[:N]
|
||||
|
||||
# Compute the features from the transformer
|
||||
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
|
||||
for l in self.layers:
|
||||
x = l(x, mask)
|
||||
|
||||
# Apply the final layernorm and return
|
||||
return self.final_layer_norm(x)
|
46
stable_diffusion/stable_diffusion/config.py
Normal file
46
stable_diffusion/stable_diffusion/config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
||||
@dataclass
|
||||
class AutoencoderConfig:
|
||||
in_channels: int = 3
|
||||
out_channels: int = 3
|
||||
latent_channels_out: int = 8
|
||||
latent_channels_in: int = 4
|
||||
block_out_channels: Tuple[int] = (128, 256, 512, 512)
|
||||
layers_per_block: int = 2
|
||||
norm_num_groups: int = 32
|
||||
scaling_factor: float = 0.18215
|
||||
|
||||
|
||||
@dataclass
|
||||
class CLIPTextModelConfig:
|
||||
num_layers: int = 23
|
||||
model_dims: int = 1024
|
||||
num_heads: int = 16
|
||||
max_length: int = 77
|
||||
vocab_size: int = 49408
|
||||
|
||||
|
||||
@dataclass
|
||||
class UNetConfig:
|
||||
in_channels: int = 4
|
||||
out_channels: int = 4
|
||||
conv_in_kernel: int = 3
|
||||
conv_out_kernel: int = 3
|
||||
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
|
||||
layers_per_block: Tuple[int] = (2, 2, 2, 2)
|
||||
mid_block_layers: int = 2
|
||||
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
|
||||
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
|
||||
cross_attention_dim: Tuple[int] = (1024,) * 4
|
||||
norm_num_groups: int = 32
|
||||
|
||||
|
||||
@dataclass
|
||||
class DiffusionConfig:
|
||||
beta_schedule: str = "scaled_linear"
|
||||
beta_start: float = 0.00085
|
||||
beta_end: float = 0.012
|
||||
num_train_steps: int = 1000
|
284
stable_diffusion/stable_diffusion/model_io.py
Normal file
284
stable_diffusion/stable_diffusion/model_io.py
Normal file
@@ -0,0 +1,284 @@
|
||||
import json
|
||||
from functools import partial
|
||||
|
||||
import numpy as np
|
||||
from huggingface_hub import hf_hub_download
|
||||
from safetensors import safe_open as safetensor_open
|
||||
|
||||
import mlx.core as mx
|
||||
from mlx.utils import tree_unflatten
|
||||
|
||||
from .clip import CLIPTextModel
|
||||
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
|
||||
from .tokenizer import Tokenizer
|
||||
from .unet import UNetModel
|
||||
from .vae import Autoencoder
|
||||
|
||||
|
||||
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
|
||||
_MODELS = {
|
||||
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
|
||||
"stabilityai/stable-diffusion-2-1-base": {
|
||||
"unet_config": "unet/config.json",
|
||||
"unet": "unet/diffusion_pytorch_model.safetensors",
|
||||
"text_encoder_config": "text_encoder/config.json",
|
||||
"text_encoder": "text_encoder/model.safetensors",
|
||||
"vae_config": "vae/config.json",
|
||||
"vae": "vae/diffusion_pytorch_model.safetensors",
|
||||
"diffusion_config": "scheduler/scheduler_config.json",
|
||||
"tokenizer_vocab": "tokenizer/vocab.json",
|
||||
"tokenizer_merges": "tokenizer/merges.txt",
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
def _from_numpy(x):
|
||||
return mx.array(np.ascontiguousarray(x))
|
||||
|
||||
|
||||
def map_unet_weights(key, value):
|
||||
# Map up/downsampling
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map attention layers
|
||||
if "to_k" in key:
|
||||
key = key.replace("to_k", "key_proj")
|
||||
if "to_out.0" in key:
|
||||
key = key.replace("to_out.0", "out_proj")
|
||||
if "to_q" in key:
|
||||
key = key.replace("to_q", "query_proj")
|
||||
if "to_v" in key:
|
||||
key = key.replace("to_v", "value_proj")
|
||||
|
||||
# Map transformer ffn
|
||||
if "ff.net.2" in key:
|
||||
key = key.replace("ff.net.2", "linear3")
|
||||
if "ff.net.0" in key:
|
||||
k1 = key.replace("ff.net.0.proj", "linear1")
|
||||
k2 = key.replace("ff.net.0.proj", "linear2")
|
||||
v1, v2 = np.split(value, 2)
|
||||
|
||||
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
|
||||
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
|
||||
return [(key, _from_numpy(value))]
|
||||
|
||||
|
||||
def map_clip_text_encoder_weights(key, value):
|
||||
# Remove prefixes
|
||||
if key.startswith("text_model."):
|
||||
key = key[11:]
|
||||
if key.startswith("embeddings."):
|
||||
key = key[11:]
|
||||
if key.startswith("encoder."):
|
||||
key = key[8:]
|
||||
|
||||
# Map attention layers
|
||||
if "self_attn." in key:
|
||||
key = key.replace("self_attn.", "attention.")
|
||||
if "q_proj." in key:
|
||||
key = key.replace("q_proj.", "query_proj.")
|
||||
if "k_proj." in key:
|
||||
key = key.replace("k_proj.", "key_proj.")
|
||||
if "v_proj." in key:
|
||||
key = key.replace("v_proj.", "value_proj.")
|
||||
|
||||
# Map ffn layers
|
||||
if "mlp.fc1" in key:
|
||||
key = key.replace("mlp.fc1", "linear1")
|
||||
if "mlp.fc2" in key:
|
||||
key = key.replace("mlp.fc2", "linear2")
|
||||
|
||||
return [(key, _from_numpy(value))]
|
||||
|
||||
|
||||
def map_vae_weights(key, value):
|
||||
# Map up/downsampling
|
||||
if "downsamplers" in key:
|
||||
key = key.replace("downsamplers.0.conv", "downsample")
|
||||
if "upsamplers" in key:
|
||||
key = key.replace("upsamplers.0.conv", "upsample")
|
||||
|
||||
# Map attention layers
|
||||
if "to_k" in key:
|
||||
key = key.replace("to_k", "key_proj")
|
||||
if "to_out.0" in key:
|
||||
key = key.replace("to_out.0", "out_proj")
|
||||
if "to_q" in key:
|
||||
key = key.replace("to_q", "query_proj")
|
||||
if "to_v" in key:
|
||||
key = key.replace("to_v", "value_proj")
|
||||
|
||||
# Map the mid block
|
||||
if "mid_block.resnets.0" in key:
|
||||
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
|
||||
if "mid_block.attentions.0" in key:
|
||||
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
|
||||
if "mid_block.resnets.1" in key:
|
||||
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
|
||||
|
||||
# Map the quant/post_quant layers
|
||||
if "quant_conv" in key:
|
||||
key = key.replace("quant_conv", "quant_proj")
|
||||
value = value.squeeze()
|
||||
|
||||
# Map the conv_shortcut to linear
|
||||
if "conv_shortcut.weight" in key:
|
||||
value = value.squeeze()
|
||||
|
||||
if len(value.shape) == 4:
|
||||
value = value.transpose(0, 2, 3, 1)
|
||||
|
||||
return [(key, _from_numpy(value))]
|
||||
|
||||
|
||||
def _flatten(params):
|
||||
return [(k, v) for p in params for (k, v) in p]
|
||||
|
||||
|
||||
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
|
||||
dtype = np.float16 if float16 else np.float32
|
||||
with safetensor_open(weight_file, framework="numpy") as f:
|
||||
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
|
||||
model.update(tree_unflatten(weights))
|
||||
|
||||
|
||||
def _check_key(key: str, part: str):
|
||||
if key not in _MODELS:
|
||||
raise ValueError(
|
||||
f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
|
||||
)
|
||||
|
||||
|
||||
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
"""Load the stable diffusion UNet from Huggingface Hub."""
|
||||
_check_key(key, "load_unet")
|
||||
|
||||
# Download the config and create the model
|
||||
unet_config = _MODELS[key]["unet_config"]
|
||||
with open(hf_hub_download(key, unet_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
n_blocks = len(config["block_out_channels"])
|
||||
model = UNetModel(
|
||||
UNetConfig(
|
||||
in_channels=config["in_channels"],
|
||||
out_channels=config["out_channels"],
|
||||
block_out_channels=config["block_out_channels"],
|
||||
layers_per_block=[config["layers_per_block"]] * n_blocks,
|
||||
num_attention_heads=config["attention_head_dim"],
|
||||
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
)
|
||||
)
|
||||
|
||||
# Download the weights and map them into the model
|
||||
unet_weights = _MODELS[key]["unet"]
|
||||
weight_file = hf_hub_download(key, unet_weights)
|
||||
_load_safetensor_weights(map_unet_weights, model, weight_file, float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
"""Load the stable diffusion text encoder from Huggingface Hub."""
|
||||
_check_key(key, "load_text_encoder")
|
||||
|
||||
# Download the config and create the model
|
||||
text_encoder_config = _MODELS[key]["text_encoder_config"]
|
||||
with open(hf_hub_download(key, text_encoder_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = CLIPTextModel(
|
||||
CLIPTextModelConfig(
|
||||
num_layers=config["num_hidden_layers"],
|
||||
model_dims=config["hidden_size"],
|
||||
num_heads=config["num_attention_heads"],
|
||||
max_length=config["max_position_embeddings"],
|
||||
vocab_size=config["vocab_size"],
|
||||
)
|
||||
)
|
||||
|
||||
# Download the weights and map them into the model
|
||||
text_encoder_weights = _MODELS[key]["text_encoder"]
|
||||
weight_file = hf_hub_download(key, text_encoder_weights)
|
||||
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
|
||||
"""Load the stable diffusion autoencoder from Huggingface Hub."""
|
||||
_check_key(key, "load_autoencoder")
|
||||
|
||||
# Download the config and create the model
|
||||
vae_config = _MODELS[key]["vae_config"]
|
||||
with open(hf_hub_download(key, vae_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
model = Autoencoder(
|
||||
AutoencoderConfig(
|
||||
in_channels=config["in_channels"],
|
||||
out_channels=config["out_channels"],
|
||||
latent_channels_out=2 * config["latent_channels"],
|
||||
latent_channels_in=config["latent_channels"],
|
||||
block_out_channels=config["block_out_channels"],
|
||||
layers_per_block=config["layers_per_block"],
|
||||
norm_num_groups=config["norm_num_groups"],
|
||||
)
|
||||
)
|
||||
|
||||
# Download the weights and map them into the model
|
||||
vae_weights = _MODELS[key]["vae"]
|
||||
weight_file = hf_hub_download(key, vae_weights)
|
||||
_load_safetensor_weights(map_vae_weights, model, weight_file, float16)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def load_diffusion_config(key: str = _DEFAULT_MODEL):
|
||||
"""Load the stable diffusion config from Huggingface Hub."""
|
||||
_check_key(key, "load_diffusion_config")
|
||||
|
||||
diffusion_config = _MODELS[key]["diffusion_config"]
|
||||
with open(hf_hub_download(key, diffusion_config)) as f:
|
||||
config = json.load(f)
|
||||
|
||||
return DiffusionConfig(
|
||||
beta_start=config["beta_start"],
|
||||
beta_end=config["beta_end"],
|
||||
beta_schedule=config["beta_schedule"],
|
||||
num_train_steps=config["num_train_timesteps"],
|
||||
)
|
||||
|
||||
|
||||
def load_tokenizer(key: str = _DEFAULT_MODEL):
|
||||
_check_key(key, "load_tokenizer")
|
||||
|
||||
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
|
||||
with open(vocab_file, encoding="utf-8") as f:
|
||||
vocab = json.load(f)
|
||||
|
||||
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
|
||||
with open(merges_file, encoding="utf-8") as f:
|
||||
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
|
||||
bpe_merges = [tuple(m.split()) for m in bpe_merges]
|
||||
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
|
||||
|
||||
return Tokenizer(bpe_ranks, vocab)
|
70
stable_diffusion/stable_diffusion/sampler.py
Normal file
70
stable_diffusion/stable_diffusion/sampler.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from .config import DiffusionConfig
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
|
||||
def _linspace(a, b, num):
|
||||
x = mx.arange(0, num) / (num - 1)
|
||||
return (b - a) * x + a
|
||||
|
||||
|
||||
def _interp(y, x_new):
|
||||
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
|
||||
x_low = x_new.astype(mx.int32)
|
||||
x_high = mx.minimum(x_low + 1, len(y) - 1)
|
||||
|
||||
y_low = y[x_low]
|
||||
y_high = y[x_high]
|
||||
delta_x = x_new - x_low
|
||||
y_new = y_low * (1 - delta_x) + delta_x * y_high
|
||||
|
||||
return y_new
|
||||
|
||||
|
||||
class SimpleEulerSampler:
|
||||
"""A simple Euler integrator that can be used to sample from our diffusion models.
|
||||
|
||||
The method ``step()`` performs one Euler step from x_t to x_t_prev.
|
||||
"""
|
||||
|
||||
def __init__(self, config: DiffusionConfig):
|
||||
# Compute the noise schedule
|
||||
if config.beta_schedule == "linear":
|
||||
betas = _linspace(
|
||||
config.beta_start, config.beta_end, config.num_train_steps
|
||||
)
|
||||
elif config.beta_schedule == "scaled_linear":
|
||||
betas = _linspace(
|
||||
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
|
||||
).square()
|
||||
else:
|
||||
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
|
||||
|
||||
alphas = 1 - betas
|
||||
alphas_cumprod = mx.cumprod(alphas)
|
||||
|
||||
self._sigmas = mx.concatenate(
|
||||
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
|
||||
)
|
||||
|
||||
def sample_prior(self, shape, dtype=mx.float32, key=None):
|
||||
noise = mx.random.normal(shape, key=key)
|
||||
return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
|
||||
|
||||
def sigmas(self, t):
|
||||
return _interp(self._sigmas, t)
|
||||
|
||||
def timesteps(self, num_steps: int, dtype=mx.float32):
|
||||
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
|
||||
return list(zip(steps, steps[1:]))
|
||||
|
||||
def step(self, eps_pred, x_t, t, t_prev):
|
||||
sigma = self.sigmas(t).astype(eps_pred.dtype)
|
||||
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
|
||||
|
||||
dt = sigma_prev - sigma
|
||||
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
|
||||
|
||||
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
|
||||
|
||||
return x_t_prev
|
95
stable_diffusion/stable_diffusion/tokenizer.py
Normal file
95
stable_diffusion/stable_diffusion/tokenizer.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import regex
|
||||
|
||||
|
||||
class Tokenizer:
|
||||
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
|
||||
def __init__(self, bpe_ranks, vocab):
|
||||
self.bpe_ranks = bpe_ranks
|
||||
self.vocab = vocab
|
||||
self.pat = regex.compile(
|
||||
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
|
||||
regex.IGNORECASE,
|
||||
)
|
||||
|
||||
self._cache = {self.bos: self.bos, self.eos: self.eos}
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
return "<|startoftext|>"
|
||||
|
||||
@property
|
||||
def bos_token(self):
|
||||
return self.vocab[self.bos]
|
||||
|
||||
@property
|
||||
def eos(self):
|
||||
return "<|endoftext|>"
|
||||
|
||||
@property
|
||||
def eos_token(self):
|
||||
return self.vocab[self.eos]
|
||||
|
||||
def bpe(self, text):
|
||||
if text in self._cache:
|
||||
return self._cache[text]
|
||||
|
||||
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
if not unique_bigrams:
|
||||
return unigrams
|
||||
|
||||
# In every iteration try to merge the two most likely bigrams. If none
|
||||
# was merged we are done.
|
||||
#
|
||||
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
|
||||
while unique_bigrams:
|
||||
bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
|
||||
if bigram not in self.bpe_ranks:
|
||||
break
|
||||
|
||||
new_unigrams = []
|
||||
skip = False
|
||||
for a, b in zip(unigrams, unigrams[1:]):
|
||||
if skip:
|
||||
skip = False
|
||||
continue
|
||||
|
||||
if (a, b) == bigram:
|
||||
new_unigrams.append(a + b)
|
||||
skip = True
|
||||
|
||||
else:
|
||||
new_unigrams.append(a)
|
||||
|
||||
if not skip:
|
||||
new_unigrams.append(b)
|
||||
|
||||
unigrams = new_unigrams
|
||||
unique_bigrams = set(zip(unigrams, unigrams[1:]))
|
||||
|
||||
self._cache[text] = unigrams
|
||||
|
||||
return unigrams
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
|
||||
# Lower case cleanup and split according to self.pat. Huggingface does
|
||||
# a much more thorough job here but this should suffice for 95% of
|
||||
# cases.
|
||||
clean_text = regex.sub(r"\s+", " ", text.lower())
|
||||
tokens = regex.findall(self.pat, clean_text)
|
||||
|
||||
# Split the tokens according to the byte-pair merge file
|
||||
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
|
||||
|
||||
# Map to token ids and return
|
||||
tokens = [self.vocab[t] for t in bpe_tokens]
|
||||
if prepend_bos:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos:
|
||||
tokens.append(self.eos_token)
|
||||
|
||||
return tokens
|
423
stable_diffusion/stable_diffusion/unet.py
Normal file
423
stable_diffusion/stable_diffusion/unet.py
Normal file
@@ -0,0 +1,423 @@
|
||||
import math
|
||||
from typing import Optional
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .config import UNetConfig
|
||||
|
||||
|
||||
def upsample_nearest(x, scale: int = 2):
|
||||
B, H, W, C = x.shape
|
||||
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
|
||||
x = x.reshape(B, H * scale, W * scale, C)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, in_channels: int, time_embed_dim: int):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
|
||||
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.linear_1(x)
|
||||
x = nn.silu(x)
|
||||
x = self.linear_2(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
model_dims: int,
|
||||
num_heads: int,
|
||||
hidden_dims: Optional[int] = None,
|
||||
memory_dims: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.LayerNorm(model_dims)
|
||||
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
|
||||
self.attn1.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
memory_dims = memory_dims or model_dims
|
||||
self.norm2 = nn.LayerNorm(model_dims)
|
||||
self.attn2 = nn.MultiHeadAttention(
|
||||
model_dims, num_heads, key_input_dims=memory_dims
|
||||
)
|
||||
self.attn2.out_proj.bias = mx.zeros(model_dims)
|
||||
|
||||
hidden_dims = hidden_dims or 4 * model_dims
|
||||
self.norm3 = nn.LayerNorm(model_dims)
|
||||
self.linear1 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear2 = nn.Linear(model_dims, hidden_dims)
|
||||
self.linear3 = nn.Linear(hidden_dims, model_dims)
|
||||
|
||||
def __call__(self, x, memory, attn_mask, memory_mask):
|
||||
# Self attention
|
||||
y = self.norm1(x)
|
||||
y = self.attn1(y, y, y, attn_mask)
|
||||
x = x + y
|
||||
|
||||
# Cross attention
|
||||
y = self.norm2(x)
|
||||
y = self.attn2(y, memory, memory, memory_mask)
|
||||
x = x + y
|
||||
|
||||
# FFN
|
||||
y = self.norm3(x)
|
||||
y_a = self.linear1(y)
|
||||
y_b = self.linear2(y)
|
||||
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
|
||||
y = self.linear3(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Transformer2D(nn.Module):
|
||||
"""A transformer model for inputs with 2 spatial dimensions."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
model_dims: int,
|
||||
encoder_dims: int,
|
||||
num_heads: int,
|
||||
num_layers: int = 1,
|
||||
norm_num_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
|
||||
self.proj_in = nn.Linear(in_channels, model_dims)
|
||||
self.transformer_blocks = [
|
||||
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
self.proj_out = nn.Linear(model_dims, in_channels)
|
||||
|
||||
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
|
||||
# Save the input to add to the output
|
||||
input_x = x
|
||||
|
||||
# Perform the input norm and projection
|
||||
B, H, W, C = x.shape
|
||||
x = self.norm(x).reshape(B, -1, C)
|
||||
x = self.proj_in(x)
|
||||
|
||||
# Apply the transformer
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
# Apply the output projection and reshape
|
||||
x = self.proj_out(x)
|
||||
x = x.reshape(B, H, W, C)
|
||||
|
||||
return x + input_x
|
||||
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: Optional[int] = None,
|
||||
groups: int = 32,
|
||||
temb_channels: Optional[int] = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
out_channels = out_channels or in_channels
|
||||
|
||||
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
|
||||
self.conv1 = nn.Conv2d(
|
||||
in_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
if temb_channels is not None:
|
||||
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
|
||||
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
|
||||
self.conv2 = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Linear(in_channels, out_channels)
|
||||
|
||||
def __call__(self, x, temb=None):
|
||||
if temb is not None:
|
||||
temb = self.time_emb_proj(nn.silu(temb))
|
||||
|
||||
y = self.norm1(x)
|
||||
y = nn.silu(y)
|
||||
y = self.conv1(y)
|
||||
if temb is not None:
|
||||
y = y + temb[:, None, None, :]
|
||||
y = self.norm2(y)
|
||||
y = nn.silu(y)
|
||||
y = self.conv2(y)
|
||||
|
||||
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class UNetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
temb_channels: int,
|
||||
prev_out_channels: Optional[int] = None,
|
||||
num_layers: int = 1,
|
||||
transformer_layers_per_block: int = 1,
|
||||
num_attention_heads: int = 8,
|
||||
cross_attention_dim=1280,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
add_cross_attention=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Prepare the in channels list for the resnets
|
||||
if prev_out_channels is None:
|
||||
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
|
||||
else:
|
||||
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
|
||||
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
|
||||
in_channels_list = [
|
||||
a + b for a, b in zip(in_channels_list, res_channels_list)
|
||||
]
|
||||
|
||||
# Add resnet blocks that also process the time embedding
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=ic,
|
||||
out_channels=out_channels,
|
||||
temb_channels=temb_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for ic in in_channels_list
|
||||
]
|
||||
|
||||
# Add optional cross attention layers
|
||||
if add_cross_attention:
|
||||
self.attentions = [
|
||||
Transformer2D(
|
||||
in_channels=out_channels,
|
||||
model_dims=out_channels,
|
||||
num_heads=num_attention_heads,
|
||||
num_layers=transformer_layers_per_block,
|
||||
encoder_dims=cross_attention_dim,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
x,
|
||||
encoder_x=None,
|
||||
temb=None,
|
||||
attn_mask=None,
|
||||
encoder_attn_mask=None,
|
||||
residual_hidden_states=None,
|
||||
):
|
||||
output_states = []
|
||||
|
||||
for i in range(len(self.resnets)):
|
||||
if residual_hidden_states is not None:
|
||||
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
|
||||
|
||||
x = self.resnets[i](x, temb)
|
||||
|
||||
if "attentions" in self:
|
||||
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
|
||||
output_states.append(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = self.downsample(x)
|
||||
output_states.append(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
output_states.append(x)
|
||||
|
||||
return x, output_states
|
||||
|
||||
|
||||
class UNetModel(nn.Module):
|
||||
"""The conditional 2D UNet model that actually performs the denoising."""
|
||||
|
||||
def __init__(self, config: UNetConfig):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
config.in_channels,
|
||||
config.block_out_channels[0],
|
||||
config.conv_in_kernel,
|
||||
padding=(config.conv_in_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
self.timesteps = nn.SinusoidalPositionalEncoding(
|
||||
config.block_out_channels[0],
|
||||
max_freq=1,
|
||||
min_freq=math.exp(
|
||||
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
|
||||
),
|
||||
scale=1.0,
|
||||
cos_first=True,
|
||||
full_turns=False,
|
||||
)
|
||||
self.time_embedding = TimestepEmbedding(
|
||||
config.block_out_channels[0],
|
||||
config.block_out_channels[0] * 4,
|
||||
)
|
||||
|
||||
# Make the downsampling blocks
|
||||
block_channels = [config.block_out_channels[0]] + list(
|
||||
config.block_out_channels
|
||||
)
|
||||
self.down_blocks = [
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
num_layers=config.layers_per_block[i],
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=(i < len(config.block_out_channels) - 1),
|
||||
add_upsample=False,
|
||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(
|
||||
zip(block_channels, block_channels[1:])
|
||||
)
|
||||
]
|
||||
|
||||
# Make the middle block
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
Transformer2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
model_dims=config.block_out_channels[-1],
|
||||
num_heads=config.num_attention_heads[-1],
|
||||
num_layers=config.transformer_layers_per_block[-1],
|
||||
encoder_dims=config.cross_attention_dim[-1],
|
||||
),
|
||||
ResnetBlock2D(
|
||||
in_channels=config.block_out_channels[-1],
|
||||
out_channels=config.block_out_channels[-1],
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
groups=config.norm_num_groups,
|
||||
),
|
||||
]
|
||||
|
||||
# Make the upsampling blocks
|
||||
block_channels = (
|
||||
[config.block_out_channels[0]]
|
||||
+ list(config.block_out_channels)
|
||||
+ [config.block_out_channels[-1]]
|
||||
)
|
||||
self.up_blocks = [
|
||||
UNetBlock2D(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
temb_channels=config.block_out_channels[0] * 4,
|
||||
prev_out_channels=prev_out_channels,
|
||||
num_layers=config.layers_per_block[i] + 1,
|
||||
transformer_layers_per_block=config.transformer_layers_per_block[i],
|
||||
num_attention_heads=config.num_attention_heads[i],
|
||||
cross_attention_dim=config.cross_attention_dim[i],
|
||||
resnet_groups=config.norm_num_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=(i > 0),
|
||||
add_cross_attention=(i < len(config.block_out_channels) - 1),
|
||||
)
|
||||
for i, (in_channels, out_channels, prev_out_channels) in reversed(
|
||||
list(
|
||||
enumerate(
|
||||
zip(block_channels, block_channels[1:], block_channels[2:])
|
||||
)
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
config.norm_num_groups,
|
||||
config.block_out_channels[0],
|
||||
pytorch_compatible=True,
|
||||
)
|
||||
self.conv_out = nn.Conv2d(
|
||||
config.block_out_channels[0],
|
||||
config.out_channels,
|
||||
config.conv_out_kernel,
|
||||
padding=(config.conv_out_kernel - 1) // 2,
|
||||
)
|
||||
|
||||
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
|
||||
|
||||
# Compute the time embeddings
|
||||
temb = self.timesteps(timestep).astype(x.dtype)
|
||||
temb = self.time_embedding(temb)
|
||||
|
||||
# Preprocess the input
|
||||
x = self.conv_in(x)
|
||||
|
||||
# Run the downsampling part of the unet
|
||||
residuals = [x]
|
||||
for block in self.down_blocks:
|
||||
x, res = block(
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
)
|
||||
residuals.extend(res)
|
||||
|
||||
# Run the middle part of the unet
|
||||
x = self.mid_blocks[0](x, temb)
|
||||
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
|
||||
x = self.mid_blocks[2](x, temb)
|
||||
|
||||
# Run the upsampling part of the unet
|
||||
for block in self.up_blocks:
|
||||
x, _ = block(
|
||||
x,
|
||||
encoder_x=encoder_x,
|
||||
temb=temb,
|
||||
attn_mask=attn_mask,
|
||||
encoder_attn_mask=encoder_attn_mask,
|
||||
residual_hidden_states=residuals,
|
||||
)
|
||||
|
||||
# Postprocess the output
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
266
stable_diffusion/stable_diffusion/vae.py
Normal file
266
stable_diffusion/stable_diffusion/vae.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import math
|
||||
from typing import List
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
|
||||
from .config import AutoencoderConfig
|
||||
from .unet import ResnetBlock2D, upsample_nearest
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""A single head unmasked attention for use with the VAE."""
|
||||
|
||||
def __init__(self, dims: int, norm_groups: int = 32):
|
||||
super().__init__()
|
||||
|
||||
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
|
||||
self.query_proj = nn.Linear(dims, dims)
|
||||
self.key_proj = nn.Linear(dims, dims)
|
||||
self.value_proj = nn.Linear(dims, dims)
|
||||
self.out_proj = nn.Linear(dims, dims)
|
||||
|
||||
def __call__(self, x):
|
||||
B, H, W, C = x.shape
|
||||
|
||||
y = self.group_norm(x)
|
||||
|
||||
queries = self.query_proj(y).reshape(B, H * W, C)
|
||||
keys = self.key_proj(y).reshape(B, H * W, C)
|
||||
values = self.value_proj(y).reshape(B, H * W, C)
|
||||
|
||||
scale = 1 / math.sqrt(queries.shape[-1])
|
||||
scores = (queries * scale) @ keys.transpose(0, 2, 1)
|
||||
attn = mx.softmax(scores, axis=-1)
|
||||
y = (attn @ values).reshape(B, H, W, C)
|
||||
|
||||
y = self.out_proj(y)
|
||||
x = x + y
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class EncoderDecoderBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
num_layers: int = 1,
|
||||
resnet_groups: int = 32,
|
||||
add_downsample=True,
|
||||
add_upsample=True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
# Add the resnet blocks
|
||||
self.resnets = [
|
||||
ResnetBlock2D(
|
||||
in_channels=in_channels if i == 0 else out_channels,
|
||||
out_channels=out_channels,
|
||||
groups=resnet_groups,
|
||||
)
|
||||
for i in range(num_layers)
|
||||
]
|
||||
|
||||
# Add an optional downsampling layer
|
||||
if add_downsample:
|
||||
self.downsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=2, padding=1
|
||||
)
|
||||
|
||||
# or upsampling layer
|
||||
if add_upsample:
|
||||
self.upsample = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
def __call__(self, x):
|
||||
for resnet in self.resnets:
|
||||
x = resnet(x)
|
||||
|
||||
if "downsample" in self:
|
||||
x = self.downsample(x)
|
||||
|
||||
if "upsample" in self:
|
||||
x = self.upsample(upsample_nearest(x))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
"""Implements the encoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
channels = [block_out_channels[0]] + list(block_out_channels)
|
||||
self.down_blocks = [
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=i < len(block_out_channels) - 1,
|
||||
add_upsample=False,
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
|
||||
]
|
||||
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[-1], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
for l in self.down_blocks:
|
||||
x = l(x)
|
||||
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Decoder(nn.Module):
|
||||
"""Implements the decoder side of the Autoencoder."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
block_out_channels: List[int] = [64],
|
||||
layers_per_block: int = 2,
|
||||
resnet_groups: int = 32,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
|
||||
)
|
||||
|
||||
self.mid_blocks = [
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
Attention(block_out_channels[-1], resnet_groups),
|
||||
ResnetBlock2D(
|
||||
in_channels=block_out_channels[-1],
|
||||
out_channels=block_out_channels[-1],
|
||||
groups=resnet_groups,
|
||||
),
|
||||
]
|
||||
|
||||
channels = list(reversed(block_out_channels))
|
||||
channels = [channels[0]] + channels
|
||||
self.up_blocks = [
|
||||
EncoderDecoderBlock2D(
|
||||
in_channels,
|
||||
out_channels,
|
||||
num_layers=layers_per_block,
|
||||
resnet_groups=resnet_groups,
|
||||
add_downsample=False,
|
||||
add_upsample=i < len(block_out_channels) - 1,
|
||||
)
|
||||
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
|
||||
]
|
||||
|
||||
self.conv_norm_out = nn.GroupNorm(
|
||||
resnet_groups, block_out_channels[0], pytorch_compatible=True
|
||||
)
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
def __call__(self, x):
|
||||
x = self.conv_in(x)
|
||||
|
||||
x = self.mid_blocks[0](x)
|
||||
x = self.mid_blocks[1](x)
|
||||
x = self.mid_blocks[2](x)
|
||||
|
||||
for l in self.up_blocks:
|
||||
x = l(x)
|
||||
|
||||
x = self.conv_norm_out(x)
|
||||
x = nn.silu(x)
|
||||
x = self.conv_out(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Autoencoder(nn.Module):
|
||||
"""The autoencoder that allows us to perform diffusion in the latent space."""
|
||||
|
||||
def __init__(self, config: AutoencoderConfig):
|
||||
super().__init__()
|
||||
|
||||
self.latent_channels = config.latent_channels_in
|
||||
self.scaling_factor = config.scaling_factor
|
||||
self.encoder = Encoder(
|
||||
config.in_channels,
|
||||
config.latent_channels_out,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
)
|
||||
self.decoder = Decoder(
|
||||
config.latent_channels_in,
|
||||
config.out_channels,
|
||||
config.block_out_channels,
|
||||
config.layers_per_block + 1,
|
||||
resnet_groups=config.norm_num_groups,
|
||||
)
|
||||
|
||||
self.quant_proj = nn.Linear(
|
||||
config.latent_channels_out, config.latent_channels_out
|
||||
)
|
||||
self.post_quant_proj = nn.Linear(
|
||||
config.latent_channels_in, config.latent_channels_in
|
||||
)
|
||||
|
||||
def decode(self, z):
|
||||
return self.decoder(self.post_quant_proj(z))
|
||||
|
||||
def __call__(self, x, key=None):
|
||||
x = self.encoder(x)
|
||||
x = self.query_proj(x)
|
||||
|
||||
mean, logvar = x.split(2, axis=-1)
|
||||
std = mx.exp(0.5 * logvar)
|
||||
z = mx.random.normal(mean.shape, key=key) * std + mean
|
||||
|
||||
x_hat = self.decode(z)
|
||||
|
||||
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)
|
55
stable_diffusion/txt2image.py
Normal file
55
stable_diffusion/txt2image.py
Normal file
@@ -0,0 +1,55 @@
|
||||
import argparse
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
|
||||
import mlx.core as mx
|
||||
|
||||
from stable_diffusion import StableDiffusion
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Generate images from a textual prompt using stable diffusion"
|
||||
)
|
||||
parser.add_argument("prompt")
|
||||
parser.add_argument("--n_images", type=int, default=4)
|
||||
parser.add_argument("--steps", type=int, default=50)
|
||||
parser.add_argument("--cfg", type=float, default=7.5)
|
||||
parser.add_argument("--negative_prompt", default="")
|
||||
parser.add_argument("--n_rows", type=int, default=1)
|
||||
parser.add_argument("--decoding_batch_size", type=int, default=1)
|
||||
parser.add_argument("--output", default="out.png")
|
||||
args = parser.parse_args()
|
||||
|
||||
sd = StableDiffusion()
|
||||
|
||||
# Generate the latent vectors using diffusion
|
||||
latents = sd.generate_latents(
|
||||
args.prompt,
|
||||
n_images=args.n_images,
|
||||
cfg_weight=args.cfg,
|
||||
num_steps=args.steps,
|
||||
negative_text=args.negative_prompt,
|
||||
)
|
||||
for x_t in tqdm(latents, total=args.steps):
|
||||
mx.simplify(x_t)
|
||||
mx.simplify(x_t)
|
||||
mx.eval(x_t)
|
||||
|
||||
# Decode them into images
|
||||
decoded = []
|
||||
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
|
||||
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
|
||||
mx.eval(decoded[-1])
|
||||
|
||||
# Arrange them on a grid
|
||||
x = mx.concatenate(decoded, axis=0)
|
||||
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
|
||||
B, H, W, C = x.shape
|
||||
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
|
||||
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
|
||||
x = (x * 255).astype(mx.uint8)
|
||||
|
||||
# Save them to disc
|
||||
im = Image.fromarray(x.__array__())
|
||||
im.save(args.output)
|
Reference in New Issue
Block a user