Add the Llama and Stable Diffusion examples

This commit is contained in:
Angelos Katharopoulos
2023-11-29 10:38:20 -08:00
parent 7f1328f333
commit b364cc56cd
17 changed files with 1916 additions and 0 deletions

View File

@@ -0,0 +1,95 @@
Stable Diffusion
================
Stable Diffusion in MLX. The implementation was ported from Hugginface's
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
and using the weights available on the Huggingface Hub by Stability AI at
[stabilitiai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
![out](generated-mlx.png)
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*
Installation
------------
The dependencies are minimal, namely:
- `safetensors` and `huggingface-hub` to load the checkpoints.
- `regex` for the tokenization
- `numpy` because safetensors needs to return some form of array
- `tqdm` and `PIL` for the `txt2image.py` script
You can install all of the above with the `requirements.txt` as follows:
pip install -r requirements.txt
Usage
------
Although each component in this repository can be used by itsself, the fastest
way to get started is by using the `StableDiffusion` class from the `diffusion`
module.
```python
from stable_diffusion import StableDiffusion
# This will download all the weights from HF hub and load the models in
# memory
sd = StableDiffusion()
# This creates a python generator that returns the latent produced by the
# reverse diffusion process.
#
# Because MLX is lazily evaluated iterating over this generator doesn't
# actually perform the computation until mx.eval() is called.
latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")
# Here we are evaluating each diffusion step but we could also evaluate
# once at the end.
for x_t in latent_generator:
mx.simplify(x_t) # remove possible redundant computation eg reuse
# scalars etc
mx.eval(x_t)
# Now x_t is the last latent from the reverse process aka x_0. We can
# decode it into an image using the stable diffusion VAE.
im = sd.decode(x_t)
```
The above is almost line for line the implementation of the `txt2image.py`
script in the root of the repository. You can use the script as follows:
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
Performance
-----------
The following table compares the performance of the UNet in stable diffusion.
We report throughput in images per second for the provided `txt2image.py`
script and the `diffusers` library using the MPS PyTorch backend.
At the time of writing this comparison convolutions are still some of the least
optimized operations in MLX. Despite that, MLX still achieves **~40% higher
throughput** than PyTorch with a batch size of 16 and ~15% higher when
comparing the optimal batch sizes.
Notably, PyTorch achieves almost ~50% higher throughput for the batch size of 1
which is unfortunate as that means that a single image can be computed faster.
However, when starting with the models not loaded in memory and PyTorch's MPS
graph kernels not cached, the compilation time more than accounts for this
speed difference.
| Batch size | PyTorch | MLX |
| ---------- | ----------- | ----------- |
| 1 | 6.25 im/s | 4.17 im/s |
| 2 | 7.14 im/s | 5.88 im/s |
| 4 |**7.69 im/s**| 7.14 im/s |
| 6 | 7.22 im/s | 8.00 im/s |
| 8 | 6.89 im/s | 8.42 im/s |
| 12 | 6.62 im/s | 8.51 im/s |
| 16 | 6.32 im/s |**8.79 im/s**|
The above experiments were made on an M2 Ultra with PyTorch version 2.1,
diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
used classifier free guidance which means that the above batch sizes result
double the images processed by the UNet.

Binary file not shown.

After

Width:  |  Height:  |  Size: 185 KiB

View File

@@ -0,0 +1,6 @@
safetensors
huggingface-hub
regex
numpy
tqdm
Pillow

View File

@@ -0,0 +1,96 @@
import time
from typing import Tuple
import mlx.core as mx
from .model_io import (
load_unet,
load_text_encoder,
load_autoencoder,
load_diffusion_config,
load_tokenizer,
_DEFAULT_MODEL,
)
from .sampler import SimpleEulerSampler
def _repeat(x, n, axis):
# Make the expanded shape
s = x.shape
s.insert(axis + 1, n)
# Expand
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
# Make the flattened shape
s.pop(axis + 1)
s[axis] *= n
return x.reshape(s)
class StableDiffusion:
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
self.dtype = mx.float16 if float16 else mx.float32
self.diffusion_config = load_diffusion_config(model)
self.unet = load_unet(model, float16)
self.text_encoder = load_text_encoder(model, float16)
self.autoencoder = load_autoencoder(model, float16)
self.sampler = SimpleEulerSampler(self.diffusion_config)
self.tokenizer = load_tokenizer(model)
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
latent_size: Tuple[int] = (64, 64),
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Tokenize the text
tokens = [self.tokenizer.tokenize(text)]
if cfg_weight > 1:
tokens += [self.tokenizer.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
tokens = mx.array(tokens)
# Compute the features
conditioning = self.text_encoder(tokens)
# Repeat the conditioning for each of the generated images
if n_images > 1:
conditioning = _repeat(conditioning, n_images, axis=0)
# Create the latent variables
x_T = self.sampler.sample_prior(
(n_images, *latent_size, self.autoencoder.latent_channels),
dtype=self.dtype
)
# Perform the denoising loop
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
x_t = x_t_prev
yield x_t
def decode(self, x_t):
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
return x

View File

@@ -0,0 +1,68 @@
import mlx.core as mx
import mlx.nn as nn
from .config import CLIPTextModelConfig
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
# Add biases to the attention projections to match CLIP
self.attention.query_proj.bias = mx.zeros(model_dims)
self.attention.key_proj.bias = mx.zeros(model_dims)
self.attention.value_proj.bias = mx.zeros(model_dims)
self.attention.out_proj.bias = mx.zeros(model_dims)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = nn.gelu_approx(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
def __call__(self, x):
# Extract some shapes
B, N = x.shape
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
for l in self.layers:
x = l(x, mask)
# Apply the final layernorm and return
return self.final_layer_norm(x)

View File

@@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class AutoencoderConfig:
in_channels: int = 3
out_channels: int = 3
latent_channels_out: int = 8
latent_channels_in: int = 4
block_out_channels: Tuple[int] = (128, 256, 512, 512)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.18215
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
@dataclass
class UNetConfig:
in_channels: int = 4
out_channels: int = 4
conv_in_kernel: int = 3
conv_out_kernel: int = 3
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: Tuple[int] = (2, 2, 2, 2)
mid_block_layers: int = 2
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
@dataclass
class DiffusionConfig:
beta_schedule: str = "scaled_linear"
beta_start: float = 0.00085
beta_end: float = 0.012
num_train_steps: int = 1000

View File

@@ -0,0 +1,284 @@
import json
from functools import partial
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors import safe_open as safetensor_open
import mlx.core as mx
from mlx.utils import tree_unflatten
from .clip import CLIPTextModel
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
from .tokenizer import Tokenizer
from .unet import UNetModel
from .vae import Autoencoder
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = {
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
"stabilityai/stable-diffusion-2-1-base": {
"unet_config": "unet/config.json",
"unet": "unet/diffusion_pytorch_model.safetensors",
"text_encoder_config": "text_encoder/config.json",
"text_encoder": "text_encoder/model.safetensors",
"vae_config": "vae/config.json",
"vae": "vae/diffusion_pytorch_model.safetensors",
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
}
}
def _from_numpy(x):
return mx.array(np.ascontiguousarray(x))
def map_unet_weights(key, value):
# Map up/downsampling
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map transformer ffn
if "ff.net.2" in key:
key = key.replace("ff.net.2", "linear3")
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = np.split(value, 2)
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
return [(key, _from_numpy(value))]
def map_clip_text_encoder_weights(key, value):
# Remove prefixes
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
return [(key, _from_numpy(value))]
def map_vae_weights(key, value):
# Map up/downsampling
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map the quant/post_quant layers
if "quant_conv" in key:
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
# Map the conv_shortcut to linear
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
return [(key, _from_numpy(value))]
def _flatten(params):
return [(k, v) for p in params for (k, v) in p]
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
dtype = np.float16 if float16 else np.float32
with safetensor_open(weight_file, framework="numpy") as f:
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
model.update(tree_unflatten(weights))
def _check_key(key: str, part: str):
if key not in _MODELS:
raise ValueError(
f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
)
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion UNet from Huggingface Hub."""
_check_key(key, "load_unet")
# Download the config and create the model
unet_config = _MODELS[key]["unet_config"]
with open(hf_hub_download(key, unet_config)) as f:
config = json.load(f)
n_blocks = len(config["block_out_channels"])
model = UNetModel(
UNetConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
num_attention_heads=config["attention_head_dim"],
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
)
)
# Download the weights and map them into the model
unet_weights = _MODELS[key]["unet"]
weight_file = hf_hub_download(key, unet_weights)
_load_safetensor_weights(map_unet_weights, model, weight_file, float16)
return model
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion text encoder from Huggingface Hub."""
_check_key(key, "load_text_encoder")
# Download the config and create the model
text_encoder_config = _MODELS[key]["text_encoder_config"]
with open(hf_hub_download(key, text_encoder_config)) as f:
config = json.load(f)
model = CLIPTextModel(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
)
)
# Download the weights and map them into the model
text_encoder_weights = _MODELS[key]["text_encoder"]
weight_file = hf_hub_download(key, text_encoder_weights)
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
return model
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion autoencoder from Huggingface Hub."""
_check_key(key, "load_autoencoder")
# Download the config and create the model
vae_config = _MODELS[key]["vae_config"]
with open(hf_hub_download(key, vae_config)) as f:
config = json.load(f)
model = Autoencoder(
AutoencoderConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
latent_channels_out=2 * config["latent_channels"],
latent_channels_in=config["latent_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=config["layers_per_block"],
norm_num_groups=config["norm_num_groups"],
)
)
# Download the weights and map them into the model
vae_weights = _MODELS[key]["vae"]
weight_file = hf_hub_download(key, vae_weights)
_load_safetensor_weights(map_vae_weights, model, weight_file, float16)
return model
def load_diffusion_config(key: str = _DEFAULT_MODEL):
"""Load the stable diffusion config from Huggingface Hub."""
_check_key(key, "load_diffusion_config")
diffusion_config = _MODELS[key]["diffusion_config"]
with open(hf_hub_download(key, diffusion_config)) as f:
config = json.load(f)
return DiffusionConfig(
beta_start=config["beta_start"],
beta_end=config["beta_end"],
beta_schedule=config["beta_schedule"],
num_train_steps=config["num_train_timesteps"],
)
def load_tokenizer(key: str = _DEFAULT_MODEL):
_check_key(key, "load_tokenizer")
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return Tokenizer(bpe_ranks, vocab)

View File

@@ -0,0 +1,70 @@
from .config import DiffusionConfig
import mlx.core as mx
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)
return (b - a) * x + a
def _interp(y, x_new):
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
x_low = x_new.astype(mx.int32)
x_high = mx.minimum(x_low + 1, len(y) - 1)
y_low = y[x_low]
y_high = y[x_high]
delta_x = x_new - x_low
y_new = y_low * (1 - delta_x) + delta_x * y_high
return y_new
class SimpleEulerSampler:
"""A simple Euler integrator that can be used to sample from our diffusion models.
The method ``step()`` performs one Euler step from x_t to x_t_prev.
"""
def __init__(self, config: DiffusionConfig):
# Compute the noise schedule
if config.beta_schedule == "linear":
betas = _linspace(
config.beta_start, config.beta_end, config.num_train_steps
)
elif config.beta_schedule == "scaled_linear":
betas = _linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
).square()
else:
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
alphas = 1 - betas
alphas_cumprod = mx.cumprod(alphas)
self._sigmas = mx.concatenate(
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, dtype=mx.float32):
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def step(self, eps_pred, x_t, t, t_prev):
sigma = self.sigmas(t).astype(eps_pred.dtype)
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
dt = sigma_prev - sigma
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev

View File

@@ -0,0 +1,95 @@
import regex
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Huggingface does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
return tokens

View File

@@ -0,0 +1,423 @@
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .config import UNetConfig
def upsample_nearest(x, scale: int = 2):
B, H, W, C = x.shape
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
x = x.reshape(B, H * scale, W * scale, C)
return x
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def __call__(self, x):
x = self.linear_1(x)
x = nn.silu(x)
x = self.linear_2(x)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
model_dims: int,
num_heads: int,
hidden_dims: Optional[int] = None,
memory_dims: Optional[int] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(model_dims)
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
self.attn1.out_proj.bias = mx.zeros(model_dims)
memory_dims = memory_dims or model_dims
self.norm2 = nn.LayerNorm(model_dims)
self.attn2 = nn.MultiHeadAttention(
model_dims, num_heads, key_input_dims=memory_dims
)
self.attn2.out_proj.bias = mx.zeros(model_dims)
hidden_dims = hidden_dims or 4 * model_dims
self.norm3 = nn.LayerNorm(model_dims)
self.linear1 = nn.Linear(model_dims, hidden_dims)
self.linear2 = nn.Linear(model_dims, hidden_dims)
self.linear3 = nn.Linear(hidden_dims, model_dims)
def __call__(self, x, memory, attn_mask, memory_mask):
# Self attention
y = self.norm1(x)
y = self.attn1(y, y, y, attn_mask)
x = x + y
# Cross attention
y = self.norm2(x)
y = self.attn2(y, memory, memory, memory_mask)
x = x + y
# FFN
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
y = self.linear3(y)
x = x + y
return x
class Transformer2D(nn.Module):
"""A transformer model for inputs with 2 spatial dimensions."""
def __init__(
self,
in_channels: int,
model_dims: int,
encoder_dims: int,
num_heads: int,
num_layers: int = 1,
norm_num_groups: int = 32,
):
super().__init__()
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
self.proj_in = nn.Linear(in_channels, model_dims)
self.transformer_blocks = [
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
for i in range(num_layers)
]
self.proj_out = nn.Linear(model_dims, in_channels)
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
for block in self.transformer_blocks:
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
# Apply the output projection and reshape
x = self.proj_out(x)
x = x.reshape(B, H, W, C)
return x + input_x
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
groups: int = 32,
temb_channels: Optional[int] = None,
):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if in_channels != out_channels:
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y)
y = nn.silu(y)
y = self.conv2(y)
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
return x
class UNetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
prev_out_channels: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 8,
cross_attention_dim=1280,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
add_cross_attention=True,
):
super().__init__()
# Prepare the in channels list for the resnets
if prev_out_channels is None:
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
else:
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
in_channels_list = [
a + b for a, b in zip(in_channels_list, res_channels_list)
]
# Add resnet blocks that also process the time embedding
self.resnets = [
ResnetBlock2D(
in_channels=ic,
out_channels=out_channels,
temb_channels=temb_channels,
groups=resnet_groups,
)
for ic in in_channels_list
]
# Add optional cross attention layers
if add_cross_attention:
self.attentions = [
Transformer2D(
in_channels=out_channels,
model_dims=out_channels,
num_heads=num_attention_heads,
num_layers=transformer_layers_per_block,
encoder_dims=cross_attention_dim,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(
self,
x,
encoder_x=None,
temb=None,
attn_mask=None,
encoder_attn_mask=None,
residual_hidden_states=None,
):
output_states = []
for i in range(len(self.resnets)):
if residual_hidden_states is not None:
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
x = self.resnets[i](x, temb)
if "attentions" in self:
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
output_states.append(x)
if "downsample" in self:
x = self.downsample(x)
output_states.append(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
output_states.append(x)
return x, output_states
class UNetModel(nn.Module):
"""The conditional 2D UNet model that actually performs the denoising."""
def __init__(self, config: UNetConfig):
super().__init__()
self.conv_in = nn.Conv2d(
config.in_channels,
config.block_out_channels[0],
config.conv_in_kernel,
padding=(config.conv_in_kernel - 1) // 2,
)
self.timesteps = nn.SinusoidalPositionalEncoding(
config.block_out_channels[0],
max_freq=1,
min_freq=math.exp(
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.time_embedding = TimestepEmbedding(
config.block_out_channels[0],
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
)
self.down_blocks = [
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
num_layers=config.layers_per_block[i],
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention=(i < len(config.block_out_channels) - 1),
)
for i, (in_channels, out_channels) in enumerate(
zip(block_channels, block_channels[1:])
)
]
# Make the middle block
self.mid_blocks = [
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
Transformer2D(
in_channels=config.block_out_channels[-1],
model_dims=config.block_out_channels[-1],
num_heads=config.num_attention_heads[-1],
num_layers=config.transformer_layers_per_block[-1],
encoder_dims=config.cross_attention_dim[-1],
),
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
]
# Make the upsampling blocks
block_channels = (
[config.block_out_channels[0]]
+ list(config.block_out_channels)
+ [config.block_out_channels[-1]]
)
self.up_blocks = [
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
prev_out_channels=prev_out_channels,
num_layers=config.layers_per_block[i] + 1,
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention=(i < len(config.block_out_channels) - 1),
)
for i, (in_channels, out_channels, prev_out_channels) in reversed(
list(
enumerate(
zip(block_channels, block_channels[1:], block_channels[2:])
)
)
)
]
self.conv_norm_out = nn.GroupNorm(
config.norm_num_groups,
config.block_out_channels[0],
pytorch_compatible=True,
)
self.conv_out = nn.Conv2d(
config.block_out_channels[0],
config.out_channels,
config.conv_out_kernel,
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Preprocess the input
x = self.conv_in(x)
# Run the downsampling part of the unet
residuals = [x]
for block in self.down_blocks:
x, res = block(
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
)
residuals.extend(res)
# Run the middle part of the unet
x = self.mid_blocks[0](x, temb)
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
x = self.mid_blocks[2](x, temb)
# Run the upsampling part of the unet
for block in self.up_blocks:
x, _ = block(
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
residual_hidden_states=residuals,
)
# Postprocess the output
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x

View File

@@ -0,0 +1,266 @@
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .config import AutoencoderConfig
from .unet import ResnetBlock2D, upsample_nearest
class Attention(nn.Module):
"""A single head unmasked attention for use with the VAE."""
def __init__(self, dims: int, norm_groups: int = 32):
super().__init__()
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
self.query_proj = nn.Linear(dims, dims)
self.key_proj = nn.Linear(dims, dims)
self.value_proj = nn.Linear(dims, dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x):
B, H, W, C = x.shape
y = self.group_norm(x)
queries = self.query_proj(y).reshape(B, H * W, C)
keys = self.key_proj(y).reshape(B, H * W, C)
values = self.value_proj(y).reshape(B, H * W, C)
scale = 1 / math.sqrt(queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 2, 1)
attn = mx.softmax(scores, axis=-1)
y = (attn @ values).reshape(B, H, W, C)
y = self.out_proj(y)
x = x + y
return x
class EncoderDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
):
super().__init__()
# Add the resnet blocks
self.resnets = [
ResnetBlock2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
groups=resnet_groups,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x):
for resnet in self.resnets:
x = resnet(x)
if "downsample" in self:
x = self.downsample(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
return x
class Encoder(nn.Module):
"""Implements the encoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
)
channels = [block_out_channels[0]] + list(block_out_channels)
self.down_blocks = [
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=i < len(block_out_channels) - 1,
add_upsample=False,
)
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
]
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[-1], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
for l in self.down_blocks:
x = l(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Decoder(nn.Module):
"""Implements the decoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
)
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
channels = list(reversed(block_out_channels))
channels = [channels[0]] + channels
self.up_blocks = [
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=False,
add_upsample=i < len(block_out_channels) - 1,
)
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[0], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
for l in self.up_blocks:
x = l(x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Autoencoder(nn.Module):
"""The autoencoder that allows us to perform diffusion in the latent space."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.latent_channels = config.latent_channels_in
self.scaling_factor = config.scaling_factor
self.encoder = Encoder(
config.in_channels,
config.latent_channels_out,
config.block_out_channels,
config.layers_per_block,
resnet_groups=config.norm_num_groups,
)
self.decoder = Decoder(
config.latent_channels_in,
config.out_channels,
config.block_out_channels,
config.layers_per_block + 1,
resnet_groups=config.norm_num_groups,
)
self.quant_proj = nn.Linear(
config.latent_channels_out, config.latent_channels_out
)
self.post_quant_proj = nn.Linear(
config.latent_channels_in, config.latent_channels_in
)
def decode(self, z):
return self.decoder(self.post_quant_proj(z))
def __call__(self, x, key=None):
x = self.encoder(x)
x = self.query_proj(x)
mean, logvar = x.split(2, axis=-1)
std = mx.exp(0.5 * logvar)
z = mx.random.normal(mean.shape, key=key) * std + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)

View File

@@ -0,0 +1,55 @@
import argparse
from PIL import Image
from tqdm import tqdm
import mlx.core as mx
from stable_diffusion import StableDiffusion
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--cfg", type=float, default=7.5)
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--output", default="out.png")
args = parser.parse_args()
sd = StableDiffusion()
# Generate the latent vectors using diffusion
latents = sd.generate_latents(
args.prompt,
n_images=args.n_images,
cfg_weight=args.cfg,
num_steps=args.steps,
negative_text=args.negative_prompt,
)
for x_t in tqdm(latents, total=args.steps):
mx.simplify(x_t)
mx.simplify(x_t)
mx.eval(x_t)
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
mx.eval(decoded[-1])
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(x.__array__())
im.save(args.output)