Add the Llama and Stable Diffusion examples

This commit is contained in:
Angelos Katharopoulos 2023-11-29 10:38:20 -08:00
parent 7f1328f333
commit b364cc56cd
17 changed files with 1916 additions and 0 deletions

37
llama/README.md Normal file
View File

@ -0,0 +1,37 @@
# LLaMA
An example of generating text with LLaMA using MLX.
LLaMA is a set of open source language models from Meta AI Research[^1] ranging from 7B to 65B parameters.
### Setup
Install the dependencies:
```
pip install -r requirements.txt
```
Next, download and convert the model. If you do not have access to the model
weights you will need to [request
access](https://docs.google.com/forms/d/e/1FAIpQLSfqNECQnMkycAp2jP4Z9TFX0cGR4uf7b_fBxjY_OjhJILlKGA/viewform)
from Meta.
Convert the weights with:
```
python convert.py <path_to_torch_weights> mlx_llama_weights.npz
```
### Run
Once you've converted the weights to MLX format, you can interact with the
LLaMA model:
```
python llama.py mlx_llama.npz tokenizer.model "hello"
```
Run `python llama.py --help` for more details.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2302.13971) and [blog post](https://ai.meta.com/blog/large-language-model-llama-meta-ai/) for more details.

46
llama/convert.py Normal file
View File

@ -0,0 +1,46 @@
import argparse
from itertools import starmap
import numpy as np
import torch
def map_torch_to_mlx(key, value):
if "tok_embedding" in key:
key = "embedding.weight"
elif "norm" in key:
key = key.replace("attention_norm", "norm1").replace("ffn_norm", "norm2")
elif "wq" in key or "wk" in key or "wv" in key or "wo" in key:
key = key.replace("wq", "query_proj")
key = key.replace("wk", "key_proj")
key = key.replace("wv", "value_proj")
key = key.replace("wo", "out_proj")
elif "w1" in key or "w2" in key or "w3" in key:
# The FFN is a separate submodule in PyTorch
key = key.replace("feed_forward.w1", "linear1")
key = key.replace("feed_forward.w3", "linear2")
key = key.replace("feed_forward.w2", "linear3")
elif "output" in key:
key = key.replace("output", "out_proj")
elif "rope" in key:
return None, None
return key, value.numpy()
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Convert Llama weights to MLX")
parser.add_argument("torch_weights")
parser.add_argument("output_file")
args = parser.parse_args()
state = torch.load(args.torch_weights)
np.savez(
args.output_file,
**{k: v for k, v in starmap(map_torch_to_mlx, state.items()) if k is not None}
)

304
llama/llama.py Normal file
View File

@ -0,0 +1,304 @@
import argparse
import math
import numpy as np
from sentencepiece import SentencePieceProcessor
import time
import mlx.core as mx
import mlx.nn as nn
from mlx.utils import tree_unflatten
class LlamaAttention(nn.Module):
def __init__(self, dims: int, num_heads: int):
super().__init__()
self.num_heads = num_heads
self.rope = nn.RoPE(dims // num_heads, traditional=True)
self.query_proj = nn.Linear(dims, dims, bias=False)
self.key_proj = nn.Linear(dims, dims, bias=False)
self.value_proj = nn.Linear(dims, dims, bias=False)
self.out_proj = nn.Linear(dims, dims, bias=False)
def __call__(self, queries, keys, values, mask=None, cache=None):
queries = self.query_proj(queries)
keys = self.key_proj(keys)
values = self.value_proj(values)
# Extract some shapes
num_heads = self.num_heads
B, L, D = queries.shape
# Prepare the queries, keys and values for the attention computation
queries = queries.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
keys = keys.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
values = values.reshape(B, L, num_heads, -1).transpose(0, 2, 1, 3)
# Add RoPE to the queries and keys and combine them with the cache
if cache is not None:
key_cache, value_cache = cache
queries = self.rope(queries, offset=key_cache.shape[2])
keys = self.rope(keys, offset=key_cache.shape[2])
keys = mx.concatenate([key_cache, keys], axis=2)
values = mx.concatenate([value_cache, values], axis=2)
else:
queries = self.rope(queries)
keys = self.rope(keys)
# Finally perform the attention computation
scale = math.sqrt(1 / queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 1, 3, 2)
if mask is not None:
scores = scores + mask
scores = mx.softmax(scores, axis=-1)
values_hat = (scores @ values).transpose(0, 2, 1, 3).reshape(B, L, -1)
# Note that we return the keys and values to possibly be used as a cache
return self.out_proj(values_hat), (keys, values)
class LlamaEncoderLayer(nn.Module):
def __init__(self, dims: int, mlp_dims: int, num_heads: int):
super().__init__()
self.attention = LlamaAttention(dims, num_heads)
self.norm1 = nn.RMSNorm(dims)
self.norm2 = nn.RMSNorm(dims)
self.linear1 = nn.Linear(dims, mlp_dims, bias=False)
self.linear2 = nn.Linear(dims, mlp_dims, bias=False)
self.linear3 = nn.Linear(mlp_dims, dims, bias=False)
def __call__(self, x, mask=None, cache=None):
y = self.norm1(x)
y, cache = self.attention(y, y, y, mask, cache)
x = x + y
y = self.norm2(x)
a = self.linear1(y)
b = self.linear2(y)
y = a * mx.sigmoid(a) * b
y = self.linear3(y)
x = x + y
return x, cache
class Llama(nn.Module):
def __init__(
self, num_layers: int, vocab_size: int, dims: int, mlp_dims: int, num_heads: int
):
super().__init__()
self.embedding = nn.Embedding(vocab_size, dims)
self.layers = [
LlamaEncoderLayer(dims, mlp_dims, num_heads) for _ in range(num_layers)
]
self.norm = nn.RMSNorm(dims)
self.out_proj = nn.Linear(dims, vocab_size, bias=False)
def __call__(self, x):
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
x = self.embedding(x)
for l in self.layers:
x, _ = l(x, mask)
x = self.norm(x)
return self.out_proj(x)
def generate(self, x, temp=1.0):
cache = []
# Make an additive causal mask. We will need that to process the prompt.
mask = nn.MultiHeadAttention.create_additive_causal_mask(x.shape[1])
mask = mask.astype(self.embedding.weight.dtype)
# First we process the prompt x the same was as in __call__ but
# save the caches in cache
x = self.embedding(x)
for l in self.layers:
x, c = l(x, mask=mask)
# We store the per layer cache in a simple python list
cache.append(c)
x = self.norm(x)
# We only care about the last logits that generate the next token
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
# y now has size [1]
# Since MLX is lazily evaluated nothing is computed yet.
# Calling y.item() would force the computation to happen at
# this point but we can also choose not to do that and let the
# user choose when to start the computation.
yield y
# Now we parsed the prompt and generated the first token we
# need to feed it back into the model and loop to generate the
# rest.
while True:
# Unsqueezing the last dimension to add a sequence length
# dimension of 1
x = y[:, None]
x = self.embedding(x)
for i in range(len(cache)):
# We are overwriting the arrays in the cache list. When
# the computation will happen, MLX will be discarding the
# old cache the moment it is not needed anymore.
x, cache[i] = self.layers[i](x, mask=None, cache=cache[i])
x = self.norm(x)
y = self.out_proj(x[:, -1])
y = mx.random.categorical(y * (1 / temp))
yield y
def tic():
return time.time()
def toc(msg, start):
end = time.time()
return f"[INFO] {msg}: {end - start:.3f} s"
def generate(args):
input("Press enter to start generation")
print("------")
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(args.prompt)])
skip = 0
prompt_processing = None
tokens = []
start = tic()
for token in model.generate(x, args.temp):
tokens.append(token)
if len(tokens) == 1:
# Actually perform the computation to measure the prompt processing time
mx.eval(token)
prompt_processing = toc("Prompt processing", start)
if len(tokens) >= args.num_tokens:
break
elif (len(tokens) % args.write_every) == 0:
# It is perfectly ok to eval things we have already eval-ed.
mx.eval(tokens)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
skip = len(s)
mx.eval(tokens)
full_gen = toc("Full generation", start)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
print()
print("------")
print(prompt_processing)
print(full_gen)
def few_shot_generate(args):
def possible_end(s):
word = "[Instruction]"
for i in range(len(word) - 1, 0, -1):
if s[-i:] == word[:i]:
return 0
if s[-len(word) :] == word:
return 1
return -1
def generate(question):
x = mx.array([[tokenizer.bos_id()] + tokenizer.encode(question)])
skip = 0
prompt_processing = None
tokens = []
start = tic()
for token in model.generate(x, args.temp):
tokens.append(token)
if len(tokens) == 1:
# Actually perform the computation to measure the prompt processing time
mx.eval(token)
prompt_processing = toc("Prompt processing", start)
if len(tokens) >= args.num_tokens:
break
mx.eval(tokens)
token_list = [t.item() for t in tokens]
s = tokenizer.decode(token_list)
end = possible_end(s)
if end == 0:
continue
if end == 1:
skip = len(s)
break
print(s[skip:], end="", flush=True)
skip = len(s)
if token_list[-1] == tokenizer.eos_id():
break
mx.eval(tokens)
full_gen = toc("Full generation", start)
s = tokenizer.decode([t.item() for t in tokens])
print(s[skip:], end="", flush=True)
prompt = open(args.prompt).read().strip()
while True:
question = input("Ask a question: ")
generate(prompt.replace("{}", question))
print()
def load_model(model_path):
weights = mx.load(model_path)
mlp_dims, dims = weights["layers.0.linear1.weight"].shape
num_heads = dims // 128
num_layers = max(int(l.split(".")[1]) for l in weights.keys() if "layers" in l) + 1
vocab_size = weights["out_proj.weight"].shape[-1]
model = Llama(num_layers, vocab_size, dims, mlp_dims, num_heads)
model.update(tree_unflatten(list(weights.items())))
mx.eval(model.parameters())
return model
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Llama inference script")
parser.add_argument("model", help="The model file containing MLX weights")
parser.add_argument("tokenizer", help="The sentencepiece tokenizer")
parser.add_argument("prompt", help="The message to be processed by the model")
parser.add_argument(
"--few-shot",
action="store_true",
help="Read a few shot prompt from a file (as in `sample_prompt.txt`).",
)
parser.add_argument(
"--num-tokens", "-n", type=int, default=100, help="How many tokens to generate"
)
parser.add_argument(
"--write-every", type=int, default=1, help="After how many tokens to detokenize"
)
parser.add_argument(
"--temp", type=float, default=0.8, help="The sampling temperature"
)
parser.add_argument("--seed", type=int, default=0, help="The PRNG seed")
args = parser.parse_args()
mx.random.seed(args.seed)
tokenizer = SentencePieceProcessor(model_file=args.tokenizer)
print("[INFO] Loading model from disk.")
model = load_model(args.model)
if args.few_shot:
few_shot_generate(args)
else:
generate(args)

2
llama/requirements.txt Normal file
View File

@ -0,0 +1,2 @@
sentencepiece
torch

23
llama/sample_prompt.txt Normal file
View File

@ -0,0 +1,23 @@
[Instruction] Give the list of U.S. states bordering Canada
[Answer] OK, here is the list of U.S. states located on the border with Canada:
- Alaska
- Michigan
- Maine
- Minnesota
- Montana
- New York
- Washington
- North Dakota
- Ohio
- Vermont
- New Hampshire
- Idaho
- Pennsylvania
[Instruction] Write a paragraph about "functional analysis"
[Answer] OK, here is a paragraph on the topic of functional analysis:
Functional analysis is a branch of mathematical analysis, the core of which is formed by the study of vector spaces endowed with some kind of limit-related structure (for example, inner product, norm, or topology) and the linear functions defined on these spaces and suitably respecting these structures. The historical roots of functional analysis lie in the study of spaces of functions and the formulation of properties of transformations of functions such as the Fourier transform as transformations defining, for example, continuous or unitary operators between function spaces. This point of view turned out to be particularly useful for the study of differential and integral equations.
[Instruction] I am starting a new dog walking business. Can you help me find 2 possible names for the business?
[Answer] OK, here are two possible names for a new dog walking business:
The first option is "Paws on Patrol", and the second option is "The Dog Whisperer".
[Instruction] {}
[Answer]

View File

@ -0,0 +1,95 @@
Stable Diffusion
================
Stable Diffusion in MLX. The implementation was ported from Hugginface's
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
and using the weights available on the Huggingface Hub by Stability AI at
[stabilitiai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
![out](generated-mlx.png)
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*
Installation
------------
The dependencies are minimal, namely:
- `safetensors` and `huggingface-hub` to load the checkpoints.
- `regex` for the tokenization
- `numpy` because safetensors needs to return some form of array
- `tqdm` and `PIL` for the `txt2image.py` script
You can install all of the above with the `requirements.txt` as follows:
pip install -r requirements.txt
Usage
------
Although each component in this repository can be used by itsself, the fastest
way to get started is by using the `StableDiffusion` class from the `diffusion`
module.
```python
from stable_diffusion import StableDiffusion
# This will download all the weights from HF hub and load the models in
# memory
sd = StableDiffusion()
# This creates a python generator that returns the latent produced by the
# reverse diffusion process.
#
# Because MLX is lazily evaluated iterating over this generator doesn't
# actually perform the computation until mx.eval() is called.
latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")
# Here we are evaluating each diffusion step but we could also evaluate
# once at the end.
for x_t in latent_generator:
mx.simplify(x_t) # remove possible redundant computation eg reuse
# scalars etc
mx.eval(x_t)
# Now x_t is the last latent from the reverse process aka x_0. We can
# decode it into an image using the stable diffusion VAE.
im = sd.decode(x_t)
```
The above is almost line for line the implementation of the `txt2image.py`
script in the root of the repository. You can use the script as follows:
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
Performance
-----------
The following table compares the performance of the UNet in stable diffusion.
We report throughput in images per second for the provided `txt2image.py`
script and the `diffusers` library using the MPS PyTorch backend.
At the time of writing this comparison convolutions are still some of the least
optimized operations in MLX. Despite that, MLX still achieves **~40% higher
throughput** than PyTorch with a batch size of 16 and ~15% higher when
comparing the optimal batch sizes.
Notably, PyTorch achieves almost ~50% higher throughput for the batch size of 1
which is unfortunate as that means that a single image can be computed faster.
However, when starting with the models not loaded in memory and PyTorch's MPS
graph kernels not cached, the compilation time more than accounts for this
speed difference.
| Batch size | PyTorch | MLX |
| ---------- | ----------- | ----------- |
| 1 | 6.25 im/s | 4.17 im/s |
| 2 | 7.14 im/s | 5.88 im/s |
| 4 |**7.69 im/s**| 7.14 im/s |
| 6 | 7.22 im/s | 8.00 im/s |
| 8 | 6.89 im/s | 8.42 im/s |
| 12 | 6.62 im/s | 8.51 im/s |
| 16 | 6.32 im/s |**8.79 im/s**|
The above experiments were made on an M2 Ultra with PyTorch version 2.1,
diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
used classifier free guidance which means that the above batch sizes result
double the images processed by the UNet.

Binary file not shown.

After

Width:  |  Height:  |  Size: 185 KiB

View File

@ -0,0 +1,6 @@
safetensors
huggingface-hub
regex
numpy
tqdm
Pillow

View File

@ -0,0 +1,96 @@
import time
from typing import Tuple
import mlx.core as mx
from .model_io import (
load_unet,
load_text_encoder,
load_autoencoder,
load_diffusion_config,
load_tokenizer,
_DEFAULT_MODEL,
)
from .sampler import SimpleEulerSampler
def _repeat(x, n, axis):
# Make the expanded shape
s = x.shape
s.insert(axis + 1, n)
# Expand
x = mx.broadcast_to(mx.expand_dims(x, axis + 1), s)
# Make the flattened shape
s.pop(axis + 1)
s[axis] *= n
return x.reshape(s)
class StableDiffusion:
def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
self.dtype = mx.float16 if float16 else mx.float32
self.diffusion_config = load_diffusion_config(model)
self.unet = load_unet(model, float16)
self.text_encoder = load_text_encoder(model, float16)
self.autoencoder = load_autoencoder(model, float16)
self.sampler = SimpleEulerSampler(self.diffusion_config)
self.tokenizer = load_tokenizer(model)
def generate_latents(
self,
text: str,
n_images: int = 1,
num_steps: int = 50,
cfg_weight: float = 7.5,
negative_text: str = "",
latent_size: Tuple[int] = (64, 64),
seed=None,
):
# Set the PRNG state
seed = seed or int(time.time())
mx.random.seed(seed)
# Tokenize the text
tokens = [self.tokenizer.tokenize(text)]
if cfg_weight > 1:
tokens += [self.tokenizer.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
tokens = mx.array(tokens)
# Compute the features
conditioning = self.text_encoder(tokens)
# Repeat the conditioning for each of the generated images
if n_images > 1:
conditioning = _repeat(conditioning, n_images, axis=0)
# Create the latent variables
x_T = self.sampler.sample_prior(
(n_images, *latent_size, self.autoencoder.latent_channels),
dtype=self.dtype
)
# Perform the denoising loop
x_t = x_T
for t, t_prev in self.sampler.timesteps(num_steps, dtype=self.dtype):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
eps_pred = eps_neg + cfg_weight * (eps_text - eps_neg)
x_t_prev = self.sampler.step(eps_pred, x_t, t, t_prev)
x_t = x_t_prev
yield x_t
def decode(self, x_t):
x = self.autoencoder.decode(x_t / self.autoencoder.scaling_factor)
x = mx.minimum(1, mx.maximum(0, x / 2 + 0.5))
return x

View File

@ -0,0 +1,68 @@
import mlx.core as mx
import mlx.nn as nn
from .config import CLIPTextModelConfig
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
self.layer_norm2 = nn.LayerNorm(model_dims)
self.attention = nn.MultiHeadAttention(model_dims, num_heads)
# Add biases to the attention projections to match CLIP
self.attention.query_proj.bias = mx.zeros(model_dims)
self.attention.key_proj.bias = mx.zeros(model_dims)
self.attention.value_proj.bias = mx.zeros(model_dims)
self.attention.out_proj.bias = mx.zeros(model_dims)
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
x = y + x
y = self.layer_norm2(x)
y = self.linear1(y)
y = nn.gelu_approx(y)
y = self.linear2(y)
x = y + x
return x
class CLIPTextModel(nn.Module):
"""Implements the text encoder transformer from CLIP."""
def __init__(self, config: CLIPTextModelConfig):
super().__init__()
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
def __call__(self, x):
# Extract some shapes
B, N = x.shape
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
for l in self.layers:
x = l(x, mask)
# Apply the final layernorm and return
return self.final_layer_norm(x)

View File

@ -0,0 +1,46 @@
from dataclasses import dataclass
from typing import Optional, Tuple
@dataclass
class AutoencoderConfig:
in_channels: int = 3
out_channels: int = 3
latent_channels_out: int = 8
latent_channels_in: int = 4
block_out_channels: Tuple[int] = (128, 256, 512, 512)
layers_per_block: int = 2
norm_num_groups: int = 32
scaling_factor: float = 0.18215
@dataclass
class CLIPTextModelConfig:
num_layers: int = 23
model_dims: int = 1024
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
@dataclass
class UNetConfig:
in_channels: int = 4
out_channels: int = 4
conv_in_kernel: int = 3
conv_out_kernel: int = 3
block_out_channels: Tuple[int] = (320, 640, 1280, 1280)
layers_per_block: Tuple[int] = (2, 2, 2, 2)
mid_block_layers: int = 2
transformer_layers_per_block: Tuple[int] = (1, 1, 1, 1)
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
@dataclass
class DiffusionConfig:
beta_schedule: str = "scaled_linear"
beta_start: float = 0.00085
beta_end: float = 0.012
num_train_steps: int = 1000

View File

@ -0,0 +1,284 @@
import json
from functools import partial
import numpy as np
from huggingface_hub import hf_hub_download
from safetensors import safe_open as safetensor_open
import mlx.core as mx
from mlx.utils import tree_unflatten
from .clip import CLIPTextModel
from .config import UNetConfig, CLIPTextModelConfig, AutoencoderConfig, DiffusionConfig
from .tokenizer import Tokenizer
from .unet import UNetModel
from .vae import Autoencoder
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = {
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
"stabilityai/stable-diffusion-2-1-base": {
"unet_config": "unet/config.json",
"unet": "unet/diffusion_pytorch_model.safetensors",
"text_encoder_config": "text_encoder/config.json",
"text_encoder": "text_encoder/model.safetensors",
"vae_config": "vae/config.json",
"vae": "vae/diffusion_pytorch_model.safetensors",
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
}
}
def _from_numpy(x):
return mx.array(np.ascontiguousarray(x))
def map_unet_weights(key, value):
# Map up/downsampling
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map transformer ffn
if "ff.net.2" in key:
key = key.replace("ff.net.2", "linear3")
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = np.split(value, 2)
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
return [(key, _from_numpy(value))]
def map_clip_text_encoder_weights(key, value):
# Remove prefixes
if key.startswith("text_model."):
key = key[11:]
if key.startswith("embeddings."):
key = key[11:]
if key.startswith("encoder."):
key = key[8:]
# Map attention layers
if "self_attn." in key:
key = key.replace("self_attn.", "attention.")
if "q_proj." in key:
key = key.replace("q_proj.", "query_proj.")
if "k_proj." in key:
key = key.replace("k_proj.", "key_proj.")
if "v_proj." in key:
key = key.replace("v_proj.", "value_proj.")
# Map ffn layers
if "mlp.fc1" in key:
key = key.replace("mlp.fc1", "linear1")
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
return [(key, _from_numpy(value))]
def map_vae_weights(key, value):
# Map up/downsampling
if "downsamplers" in key:
key = key.replace("downsamplers.0.conv", "downsample")
if "upsamplers" in key:
key = key.replace("upsamplers.0.conv", "upsample")
# Map attention layers
if "to_k" in key:
key = key.replace("to_k", "key_proj")
if "to_out.0" in key:
key = key.replace("to_out.0", "out_proj")
if "to_q" in key:
key = key.replace("to_q", "query_proj")
if "to_v" in key:
key = key.replace("to_v", "value_proj")
# Map the mid block
if "mid_block.resnets.0" in key:
key = key.replace("mid_block.resnets.0", "mid_blocks.0")
if "mid_block.attentions.0" in key:
key = key.replace("mid_block.attentions.0", "mid_blocks.1")
if "mid_block.resnets.1" in key:
key = key.replace("mid_block.resnets.1", "mid_blocks.2")
# Map the quant/post_quant layers
if "quant_conv" in key:
key = key.replace("quant_conv", "quant_proj")
value = value.squeeze()
# Map the conv_shortcut to linear
if "conv_shortcut.weight" in key:
value = value.squeeze()
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
return [(key, _from_numpy(value))]
def _flatten(params):
return [(k, v) for p in params for (k, v) in p]
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
dtype = np.float16 if float16 else np.float32
with safetensor_open(weight_file, framework="numpy") as f:
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
model.update(tree_unflatten(weights))
def _check_key(key: str, part: str):
if key not in _MODELS:
raise ValueError(
f"[{part}] '{key}' model not found, choose one of {{{','.join(_MODELS.keys())}}}"
)
def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion UNet from Huggingface Hub."""
_check_key(key, "load_unet")
# Download the config and create the model
unet_config = _MODELS[key]["unet_config"]
with open(hf_hub_download(key, unet_config)) as f:
config = json.load(f)
n_blocks = len(config["block_out_channels"])
model = UNetModel(
UNetConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
num_attention_heads=config["attention_head_dim"],
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
)
)
# Download the weights and map them into the model
unet_weights = _MODELS[key]["unet"]
weight_file = hf_hub_download(key, unet_weights)
_load_safetensor_weights(map_unet_weights, model, weight_file, float16)
return model
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion text encoder from Huggingface Hub."""
_check_key(key, "load_text_encoder")
# Download the config and create the model
text_encoder_config = _MODELS[key]["text_encoder_config"]
with open(hf_hub_download(key, text_encoder_config)) as f:
config = json.load(f)
model = CLIPTextModel(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
model_dims=config["hidden_size"],
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
)
)
# Download the weights and map them into the model
text_encoder_weights = _MODELS[key]["text_encoder"]
weight_file = hf_hub_download(key, text_encoder_weights)
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
return model
def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
"""Load the stable diffusion autoencoder from Huggingface Hub."""
_check_key(key, "load_autoencoder")
# Download the config and create the model
vae_config = _MODELS[key]["vae_config"]
with open(hf_hub_download(key, vae_config)) as f:
config = json.load(f)
model = Autoencoder(
AutoencoderConfig(
in_channels=config["in_channels"],
out_channels=config["out_channels"],
latent_channels_out=2 * config["latent_channels"],
latent_channels_in=config["latent_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=config["layers_per_block"],
norm_num_groups=config["norm_num_groups"],
)
)
# Download the weights and map them into the model
vae_weights = _MODELS[key]["vae"]
weight_file = hf_hub_download(key, vae_weights)
_load_safetensor_weights(map_vae_weights, model, weight_file, float16)
return model
def load_diffusion_config(key: str = _DEFAULT_MODEL):
"""Load the stable diffusion config from Huggingface Hub."""
_check_key(key, "load_diffusion_config")
diffusion_config = _MODELS[key]["diffusion_config"]
with open(hf_hub_download(key, diffusion_config)) as f:
config = json.load(f)
return DiffusionConfig(
beta_start=config["beta_start"],
beta_end=config["beta_end"],
beta_schedule=config["beta_schedule"],
num_train_steps=config["num_train_timesteps"],
)
def load_tokenizer(key: str = _DEFAULT_MODEL):
_check_key(key, "load_tokenizer")
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]
bpe_ranks = dict(map(reversed, enumerate(bpe_merges)))
return Tokenizer(bpe_ranks, vocab)

View File

@ -0,0 +1,70 @@
from .config import DiffusionConfig
import mlx.core as mx
def _linspace(a, b, num):
x = mx.arange(0, num) / (num - 1)
return (b - a) * x + a
def _interp(y, x_new):
"""Interpolate the function defined by (arange(0, len(y)), y) at positions x_new."""
x_low = x_new.astype(mx.int32)
x_high = mx.minimum(x_low + 1, len(y) - 1)
y_low = y[x_low]
y_high = y[x_high]
delta_x = x_new - x_low
y_new = y_low * (1 - delta_x) + delta_x * y_high
return y_new
class SimpleEulerSampler:
"""A simple Euler integrator that can be used to sample from our diffusion models.
The method ``step()`` performs one Euler step from x_t to x_t_prev.
"""
def __init__(self, config: DiffusionConfig):
# Compute the noise schedule
if config.beta_schedule == "linear":
betas = _linspace(
config.beta_start, config.beta_end, config.num_train_steps
)
elif config.beta_schedule == "scaled_linear":
betas = _linspace(
config.beta_start**0.5, config.beta_end**0.5, config.num_train_steps
).square()
else:
raise NotImplementedError(f"{config.beta_schedule} is not implemented.")
alphas = 1 - betas
alphas_cumprod = mx.cumprod(alphas)
self._sigmas = mx.concatenate(
[mx.zeros(1), ((1 - alphas_cumprod) / alphas_cumprod).sqrt()]
)
def sample_prior(self, shape, dtype=mx.float32, key=None):
noise = mx.random.normal(shape, key=key)
return (noise * self._sigmas[-1] * (self._sigmas[-1].square() + 1).rsqrt()).astype(dtype)
def sigmas(self, t):
return _interp(self._sigmas, t)
def timesteps(self, num_steps: int, dtype=mx.float32):
steps = _linspace(len(self._sigmas) - 1, 0, num_steps + 1).astype(dtype)
return list(zip(steps, steps[1:]))
def step(self, eps_pred, x_t, t, t_prev):
sigma = self.sigmas(t).astype(eps_pred.dtype)
sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
dt = sigma_prev - sigma
x_t_prev = (sigma.square() + 1).sqrt() * x_t + eps_pred * dt
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev

View File

@ -0,0 +1,95 @@
import regex
class Tokenizer:
"""A simple port of CLIPTokenizer from https://github.com/huggingface/transformers/ ."""
def __init__(self, bpe_ranks, vocab):
self.bpe_ranks = bpe_ranks
self.vocab = vocab
self.pat = regex.compile(
r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+""",
regex.IGNORECASE,
)
self._cache = {self.bos: self.bos, self.eos: self.eos}
@property
def bos(self):
return "<|startoftext|>"
@property
def bos_token(self):
return self.vocab[self.bos]
@property
def eos(self):
return "<|endoftext|>"
@property
def eos_token(self):
return self.vocab[self.eos]
def bpe(self, text):
if text in self._cache:
return self._cache[text]
unigrams = list(text[:-1]) + [text[-1] + "</w>"]
unique_bigrams = set(zip(unigrams, unigrams[1:]))
if not unique_bigrams:
return unigrams
# In every iteration try to merge the two most likely bigrams. If none
# was merged we are done.
#
# Ported from https://github.com/huggingface/transformers/blob/main/src/transformers/models/clip/tokenization_clip.py
while unique_bigrams:
bigram = min(unique_bigrams, key=lambda pair: self.bpe_ranks.get(pair, float("inf")))
if bigram not in self.bpe_ranks:
break
new_unigrams = []
skip = False
for a, b in zip(unigrams, unigrams[1:]):
if skip:
skip = False
continue
if (a, b) == bigram:
new_unigrams.append(a + b)
skip = True
else:
new_unigrams.append(a)
if not skip:
new_unigrams.append(b)
unigrams = new_unigrams
unique_bigrams = set(zip(unigrams, unigrams[1:]))
self._cache[text] = unigrams
return unigrams
def tokenize(self, text, prepend_bos=True, append_eos=True):
if isinstance(text, list):
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
# Lower case cleanup and split according to self.pat. Huggingface does
# a much more thorough job here but this should suffice for 95% of
# cases.
clean_text = regex.sub(r"\s+", " ", text.lower())
tokens = regex.findall(self.pat, clean_text)
# Split the tokens according to the byte-pair merge file
bpe_tokens = [ti for t in tokens for ti in self.bpe(t)]
# Map to token ids and return
tokens = [self.vocab[t] for t in bpe_tokens]
if prepend_bos:
tokens = [self.bos_token] + tokens
if append_eos:
tokens.append(self.eos_token)
return tokens

View File

@ -0,0 +1,423 @@
import math
from typing import Optional
import mlx.core as mx
import mlx.nn as nn
from .config import UNetConfig
def upsample_nearest(x, scale: int = 2):
B, H, W, C = x.shape
x = mx.broadcast_to(x[:, :, None, :, None, :], (B, H, scale, W, scale, C))
x = x.reshape(B, H * scale, W * scale, C)
return x
class TimestepEmbedding(nn.Module):
def __init__(self, in_channels: int, time_embed_dim: int):
super().__init__()
self.linear_1 = nn.Linear(in_channels, time_embed_dim)
self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim)
def __call__(self, x):
x = self.linear_1(x)
x = nn.silu(x)
x = self.linear_2(x)
return x
class TransformerBlock(nn.Module):
def __init__(
self,
model_dims: int,
num_heads: int,
hidden_dims: Optional[int] = None,
memory_dims: Optional[int] = None,
):
super().__init__()
self.norm1 = nn.LayerNorm(model_dims)
self.attn1 = nn.MultiHeadAttention(model_dims, num_heads)
self.attn1.out_proj.bias = mx.zeros(model_dims)
memory_dims = memory_dims or model_dims
self.norm2 = nn.LayerNorm(model_dims)
self.attn2 = nn.MultiHeadAttention(
model_dims, num_heads, key_input_dims=memory_dims
)
self.attn2.out_proj.bias = mx.zeros(model_dims)
hidden_dims = hidden_dims or 4 * model_dims
self.norm3 = nn.LayerNorm(model_dims)
self.linear1 = nn.Linear(model_dims, hidden_dims)
self.linear2 = nn.Linear(model_dims, hidden_dims)
self.linear3 = nn.Linear(hidden_dims, model_dims)
def __call__(self, x, memory, attn_mask, memory_mask):
# Self attention
y = self.norm1(x)
y = self.attn1(y, y, y, attn_mask)
x = x + y
# Cross attention
y = self.norm2(x)
y = self.attn2(y, memory, memory, memory_mask)
x = x + y
# FFN
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
y = self.linear3(y)
x = x + y
return x
class Transformer2D(nn.Module):
"""A transformer model for inputs with 2 spatial dimensions."""
def __init__(
self,
in_channels: int,
model_dims: int,
encoder_dims: int,
num_heads: int,
num_layers: int = 1,
norm_num_groups: int = 32,
):
super().__init__()
self.norm = nn.GroupNorm(norm_num_groups, in_channels, pytorch_compatible=True)
self.proj_in = nn.Linear(in_channels, model_dims)
self.transformer_blocks = [
TransformerBlock(model_dims, num_heads, memory_dims=encoder_dims)
for i in range(num_layers)
]
self.proj_out = nn.Linear(model_dims, in_channels)
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
for block in self.transformer_blocks:
x = block(x, encoder_x, attn_mask, encoder_attn_mask)
# Apply the output projection and reshape
x = self.proj_out(x)
x = x.reshape(B, H, W, C)
return x + input_x
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: Optional[int] = None,
groups: int = 32,
temb_channels: Optional[int] = None,
):
super().__init__()
out_channels = out_channels or in_channels
self.norm1 = nn.GroupNorm(groups, in_channels, pytorch_compatible=True)
self.conv1 = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if temb_channels is not None:
self.time_emb_proj = nn.Linear(temb_channels, out_channels)
self.norm2 = nn.GroupNorm(groups, out_channels, pytorch_compatible=True)
self.conv2 = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
if in_channels != out_channels:
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y)
y = nn.silu(y)
y = self.conv2(y)
x = y + (x if "conv_shortcut" not in self else self.conv_shortcut(x))
return x
class UNetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
prev_out_channels: Optional[int] = None,
num_layers: int = 1,
transformer_layers_per_block: int = 1,
num_attention_heads: int = 8,
cross_attention_dim=1280,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
add_cross_attention=True,
):
super().__init__()
# Prepare the in channels list for the resnets
if prev_out_channels is None:
in_channels_list = [in_channels] + [out_channels] * (num_layers - 1)
else:
in_channels_list = [prev_out_channels] + [out_channels] * (num_layers - 1)
res_channels_list = [out_channels] * (num_layers - 1) + [in_channels]
in_channels_list = [
a + b for a, b in zip(in_channels_list, res_channels_list)
]
# Add resnet blocks that also process the time embedding
self.resnets = [
ResnetBlock2D(
in_channels=ic,
out_channels=out_channels,
temb_channels=temb_channels,
groups=resnet_groups,
)
for ic in in_channels_list
]
# Add optional cross attention layers
if add_cross_attention:
self.attentions = [
Transformer2D(
in_channels=out_channels,
model_dims=out_channels,
num_heads=num_attention_heads,
num_layers=transformer_layers_per_block,
encoder_dims=cross_attention_dim,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(
self,
x,
encoder_x=None,
temb=None,
attn_mask=None,
encoder_attn_mask=None,
residual_hidden_states=None,
):
output_states = []
for i in range(len(self.resnets)):
if residual_hidden_states is not None:
x = mx.concatenate([x, residual_hidden_states.pop()], axis=-1)
x = self.resnets[i](x, temb)
if "attentions" in self:
x = self.attentions[i](x, encoder_x, attn_mask, encoder_attn_mask)
output_states.append(x)
if "downsample" in self:
x = self.downsample(x)
output_states.append(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
output_states.append(x)
return x, output_states
class UNetModel(nn.Module):
"""The conditional 2D UNet model that actually performs the denoising."""
def __init__(self, config: UNetConfig):
super().__init__()
self.conv_in = nn.Conv2d(
config.in_channels,
config.block_out_channels[0],
config.conv_in_kernel,
padding=(config.conv_in_kernel - 1) // 2,
)
self.timesteps = nn.SinusoidalPositionalEncoding(
config.block_out_channels[0],
max_freq=1,
min_freq=math.exp(
-math.log(10000) + 2 * math.log(10000) / config.block_out_channels[0]
),
scale=1.0,
cos_first=True,
full_turns=False,
)
self.time_embedding = TimestepEmbedding(
config.block_out_channels[0],
config.block_out_channels[0] * 4,
)
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
)
self.down_blocks = [
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
num_layers=config.layers_per_block[i],
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention=(i < len(config.block_out_channels) - 1),
)
for i, (in_channels, out_channels) in enumerate(
zip(block_channels, block_channels[1:])
)
]
# Make the middle block
self.mid_blocks = [
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
Transformer2D(
in_channels=config.block_out_channels[-1],
model_dims=config.block_out_channels[-1],
num_heads=config.num_attention_heads[-1],
num_layers=config.transformer_layers_per_block[-1],
encoder_dims=config.cross_attention_dim[-1],
),
ResnetBlock2D(
in_channels=config.block_out_channels[-1],
out_channels=config.block_out_channels[-1],
temb_channels=config.block_out_channels[0] * 4,
groups=config.norm_num_groups,
),
]
# Make the upsampling blocks
block_channels = (
[config.block_out_channels[0]]
+ list(config.block_out_channels)
+ [config.block_out_channels[-1]]
)
self.up_blocks = [
UNetBlock2D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=config.block_out_channels[0] * 4,
prev_out_channels=prev_out_channels,
num_layers=config.layers_per_block[i] + 1,
transformer_layers_per_block=config.transformer_layers_per_block[i],
num_attention_heads=config.num_attention_heads[i],
cross_attention_dim=config.cross_attention_dim[i],
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention=(i < len(config.block_out_channels) - 1),
)
for i, (in_channels, out_channels, prev_out_channels) in reversed(
list(
enumerate(
zip(block_channels, block_channels[1:], block_channels[2:])
)
)
)
]
self.conv_norm_out = nn.GroupNorm(
config.norm_num_groups,
config.block_out_channels[0],
pytorch_compatible=True,
)
self.conv_out = nn.Conv2d(
config.block_out_channels[0],
config.out_channels,
config.conv_out_kernel,
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Preprocess the input
x = self.conv_in(x)
# Run the downsampling part of the unet
residuals = [x]
for block in self.down_blocks:
x, res = block(
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
)
residuals.extend(res)
# Run the middle part of the unet
x = self.mid_blocks[0](x, temb)
x = self.mid_blocks[1](x, encoder_x, attn_mask, encoder_attn_mask)
x = self.mid_blocks[2](x, temb)
# Run the upsampling part of the unet
for block in self.up_blocks:
x, _ = block(
x,
encoder_x=encoder_x,
temb=temb,
attn_mask=attn_mask,
encoder_attn_mask=encoder_attn_mask,
residual_hidden_states=residuals,
)
# Postprocess the output
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x

View File

@ -0,0 +1,266 @@
import math
from typing import List
import mlx.core as mx
import mlx.nn as nn
from .config import AutoencoderConfig
from .unet import ResnetBlock2D, upsample_nearest
class Attention(nn.Module):
"""A single head unmasked attention for use with the VAE."""
def __init__(self, dims: int, norm_groups: int = 32):
super().__init__()
self.group_norm = nn.GroupNorm(norm_groups, dims, pytorch_compatible=True)
self.query_proj = nn.Linear(dims, dims)
self.key_proj = nn.Linear(dims, dims)
self.value_proj = nn.Linear(dims, dims)
self.out_proj = nn.Linear(dims, dims)
def __call__(self, x):
B, H, W, C = x.shape
y = self.group_norm(x)
queries = self.query_proj(y).reshape(B, H * W, C)
keys = self.key_proj(y).reshape(B, H * W, C)
values = self.value_proj(y).reshape(B, H * W, C)
scale = 1 / math.sqrt(queries.shape[-1])
scores = (queries * scale) @ keys.transpose(0, 2, 1)
attn = mx.softmax(scores, axis=-1)
y = (attn @ values).reshape(B, H, W, C)
y = self.out_proj(y)
x = x + y
return x
class EncoderDecoderBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
num_layers: int = 1,
resnet_groups: int = 32,
add_downsample=True,
add_upsample=True,
):
super().__init__()
# Add the resnet blocks
self.resnets = [
ResnetBlock2D(
in_channels=in_channels if i == 0 else out_channels,
out_channels=out_channels,
groups=resnet_groups,
)
for i in range(num_layers)
]
# Add an optional downsampling layer
if add_downsample:
self.downsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=2, padding=1
)
# or upsampling layer
if add_upsample:
self.upsample = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=1, padding=1
)
def __call__(self, x):
for resnet in self.resnets:
x = resnet(x)
if "downsample" in self:
x = self.downsample(x)
if "upsample" in self:
x = self.upsample(upsample_nearest(x))
return x
class Encoder(nn.Module):
"""Implements the encoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1
)
channels = [block_out_channels[0]] + list(block_out_channels)
self.down_blocks = [
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=i < len(block_out_channels) - 1,
add_upsample=False,
)
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
]
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[-1], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[-1], out_channels, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
for l in self.down_blocks:
x = l(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Decoder(nn.Module):
"""Implements the decoder side of the Autoencoder."""
def __init__(
self,
in_channels: int,
out_channels: int,
block_out_channels: List[int] = [64],
layers_per_block: int = 2,
resnet_groups: int = 32,
):
super().__init__()
self.conv_in = nn.Conv2d(
in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1
)
self.mid_blocks = [
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
Attention(block_out_channels[-1], resnet_groups),
ResnetBlock2D(
in_channels=block_out_channels[-1],
out_channels=block_out_channels[-1],
groups=resnet_groups,
),
]
channels = list(reversed(block_out_channels))
channels = [channels[0]] + channels
self.up_blocks = [
EncoderDecoderBlock2D(
in_channels,
out_channels,
num_layers=layers_per_block,
resnet_groups=resnet_groups,
add_downsample=False,
add_upsample=i < len(block_out_channels) - 1,
)
for i, (in_channels, out_channels) in enumerate(zip(channels, channels[1:]))
]
self.conv_norm_out = nn.GroupNorm(
resnet_groups, block_out_channels[0], pytorch_compatible=True
)
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
def __call__(self, x):
x = self.conv_in(x)
x = self.mid_blocks[0](x)
x = self.mid_blocks[1](x)
x = self.mid_blocks[2](x)
for l in self.up_blocks:
x = l(x)
x = self.conv_norm_out(x)
x = nn.silu(x)
x = self.conv_out(x)
return x
class Autoencoder(nn.Module):
"""The autoencoder that allows us to perform diffusion in the latent space."""
def __init__(self, config: AutoencoderConfig):
super().__init__()
self.latent_channels = config.latent_channels_in
self.scaling_factor = config.scaling_factor
self.encoder = Encoder(
config.in_channels,
config.latent_channels_out,
config.block_out_channels,
config.layers_per_block,
resnet_groups=config.norm_num_groups,
)
self.decoder = Decoder(
config.latent_channels_in,
config.out_channels,
config.block_out_channels,
config.layers_per_block + 1,
resnet_groups=config.norm_num_groups,
)
self.quant_proj = nn.Linear(
config.latent_channels_out, config.latent_channels_out
)
self.post_quant_proj = nn.Linear(
config.latent_channels_in, config.latent_channels_in
)
def decode(self, z):
return self.decoder(self.post_quant_proj(z))
def __call__(self, x, key=None):
x = self.encoder(x)
x = self.query_proj(x)
mean, logvar = x.split(2, axis=-1)
std = mx.exp(0.5 * logvar)
z = mx.random.normal(mean.shape, key=key) * std + mean
x_hat = self.decode(z)
return dict(x_hat=x_hat, z=z, mean=mean, logvar=logvar)

View File

@ -0,0 +1,55 @@
import argparse
from PIL import Image
from tqdm import tqdm
import mlx.core as mx
from stable_diffusion import StableDiffusion
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--cfg", type=float, default=7.5)
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--output", default="out.png")
args = parser.parse_args()
sd = StableDiffusion()
# Generate the latent vectors using diffusion
latents = sd.generate_latents(
args.prompt,
n_images=args.n_images,
cfg_weight=args.cfg,
num_steps=args.steps,
negative_text=args.negative_prompt,
)
for x_t in tqdm(latents, total=args.steps):
mx.simplify(x_t)
mx.simplify(x_t)
mx.eval(x_t)
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
mx.eval(decoded[-1])
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
x = mx.pad(x, [(0, 0), (8, 8), (8, 8), (0, 0)])
B, H, W, C = x.shape
x = x.reshape(args.n_rows, B // args.n_rows, H, W, C).transpose(0, 2, 1, 3, 4)
x = x.reshape(args.n_rows * H, B // args.n_rows * W, C)
x = (x * 255).astype(mx.uint8)
# Save them to disc
im = Image.fromarray(x.__array__())
im.save(args.output)