Stable diffusion XL (#516)

Angelos Katharopoulos 2024-03-08 10:24:19 -08:00 committed by GitHub
parent 8c2cf665ed
commit 3a9e6c3f70
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 449 additions and 105 deletions

View File

@ -2,22 +2,25 @@ Stable Diffusion
================
Stable Diffusion in MLX. The implementation was ported from Hugging Face's
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
and using the weights available on the Hugging Face Hub by Stability AI at
[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
[diffusers](https://huggingface.co/docs/diffusers/index) and model weights are
downloaded directly from the Hugging Face Hub. The implementation currently
supports the following models:
- [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
![out](generated-mlx.png)
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign
saying MLX in capital letters.'*
Installation
------------
The dependencies are minimal, namely:
- `safetensors` and `huggingface-hub` to load the checkpoints.
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `numpy` because safetensors needs to return some form of array
- `tqdm` and `PIL` for the `txt2image.py` script
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
You can install all of the above with the `requirements.txt` as follows:
@ -43,7 +46,9 @@ sd = StableDiffusion()
#
# Because MLX is lazily evaluated iterating over this generator doesn't
# actually perform the computation until mx.eval() is called.
latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")
latent_generator = sd.generate_latents(
"A photo of an astronaut riding a horse on Mars."
)
# Here we are evaluating each diffusion step but we could also evaluate
# once at the end.
@ -55,10 +60,16 @@ for x_t in latent_generator:
im = sd.decode(x_t)
```
The above is almost line for line the implementation of the `txt2image.py`
script in the root of the repository. You can use the script as follows:
The above is essentially the implementation of the `txt2image.py` script in the
root of the repository. You can use the script as follows:
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
```shell
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
```
You can select the model using the `--model` argument. Currently supported models
are `sdxl` (the default) and `sd`.
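The same selection can be made from Python. The following is a minimal sketch
based on the `txt2image.py` changes in this commit (the turbo defaults of 2
steps and a CFG weight of 0 come from the script):

```python
from stable_diffusion import StableDiffusion, StableDiffusionXL

# --model sdxl: SDXL turbo, which the script runs with cfg 0.0 and 2 steps by default
sdxl = StableDiffusionXL("stabilityai/sdxl-turbo", float16=True)

# --model sd: Stable Diffusion 2.1 base, with cfg 7.5 and 50 steps by default
sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=True)
```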
Image 2 Image
-------------
@ -71,48 +82,36 @@ to the forward diffusion process and the `strength` parameter. A `strength` of
random noise.
![image2image](im2im.png)
*Generations with varying strength using the original image and the prompt 'A lit fireplace'.*
The command to generate the above images is:
python image2image.py --strength 0.5 original.png 'A lit fireplace'
```shell
python image2image.py --strength 0.5 original.png 'A lit fireplace'
```
*Note: `image2image.py` will automatically downsample your input image to guarantee that its dimensions are divisible by 64. If you want full control of this process, resize your image prior to using the script.*
> [!Note]
> `image2image.py` will automatically downsample your input image to guarantee
> that its dimensions are divisible by 64. If you want full control of this
> process, resize your image prior to using the script.
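If you would rather control the resizing yourself, a helper along the following
lines does the job; the function below is purely illustrative and not part of
the repository:

```python
from PIL import Image

def fit_to_multiple_of_64(path: str, out_path: str) -> None:
    # Round both dimensions down to the nearest multiple of 64 and resize.
    img = Image.open(path)
    w, h = img.size
    img = img.resize((64 * (w // 64), 64 * (h // 64)))
    img.save(out_path)

fit_to_multiple_of_64("original.png", "original_resized.png")
```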
Performance
-----------
Memory constrained devices
--------------------------
The following table compares the performance of the UNet in stable diffusion.
We report throughput in images per second **processed by the UNet** for the
provided `txt2image.py` script and the `diffusers` library using the MPS
PyTorch backend.
The `txt2image.py` script loads the model in float16 by default, which
significantly reduces the memory required for image generation. However, since
the Stable Diffusion XL UNet alone has 2.6B parameters, quantization is
practically necessary in order to use it on devices with 8GB of RAM.
At the time of writing this comparison, convolutions are still some of the least
optimized operations in MLX. Despite that, MLX still achieves **~40% higher
throughput** than PyTorch with a batch size of 16 and ~15% higher when
comparing the optimal batch sizes.
The `txt2image.py` script supports quantization via the `-q` or `--quantize`
command line flag. When quantization is enabled, the script quantizes the
text encoder models to 4 bits and the UNet to 8 bits. This allows generating
images on an 8GB Mac mini without swapping.
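In terms of the Python API, the `-q` flag corresponds roughly to the following
calls, taken from the `txt2image.py` changes in this commit (the text encoders
use the default `quantize_module` settings, which are assumed here to be
4-bit):

```python
from mlx.nn import QuantizedLinear
from stable_diffusion import StableDiffusionXL

sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=True)

# Text encoders with the default settings, the UNet at 8 bits with group size 32.
QuantizedLinear.quantize_module(sd.text_encoder_1)
QuantizedLinear.quantize_module(sd.text_encoder_2)
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
```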
Notably, PyTorch achieves roughly 50% higher throughput at a batch size of 1,
which is unfortunate as it means a single image can be generated faster with
PyTorch. However, when starting with the models not yet loaded in memory and
PyTorch's MPS graph kernels not cached, the compilation time more than accounts
for this speed difference.
```shell
python txt2image.py --n_images 4 -q -v --output still-life.png "A painting of a vase on a wooden table, dark background, still life."
```
| Batch size | PyTorch | MLX |
| ---------- | ----------- | ----------- |
| 1 | 6.25 im/s | 4.17 im/s |
| 2 | 7.14 im/s | 5.88 im/s |
| 4 |**7.69 im/s**| 7.14 im/s |
| 6 | 7.22 im/s | 8.00 im/s |
| 8 | 6.89 im/s | 8.42 im/s |
| 12 | 6.62 im/s | 8.51 im/s |
| 16 | 6.32 im/s |**8.79 im/s**|
The above experiments were made on an M2 Ultra with PyTorch version 2.1,
diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
used classifier-free guidance, which means that the above batch sizes result in
double the number of images processed by the UNet.
Note that the above table implies that fully generating 16 images with 50
diffusion steps and classifier-free guidance takes about 90 seconds with MLX
and about 120 seconds with PyTorch.
![painting](still-life.png)
*Image generated using Stable Diffusion XL turbo in MLX with the above command on an 8GB M1 Mac mini*

View File

@ -22,10 +22,19 @@ if __name__ == "__main__":
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--no-float16", dest="float16", action="store_false")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
sd = StableDiffusion()
sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
if args.quantize:
    QuantizedLinear.quantize_module(sd.text_encoder)
    QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)

if args.preload_models:
    sd.ensure_models_are_loaded()
# Read the image
img = Image.open(args.image)
@ -52,11 +61,20 @@ if __name__ == "__main__":
for x_t in tqdm(latents, total=int(args.steps * args.strength)):
mx.eval(x_t)
# The following is not necessary but it may help in memory
# constrained systems by reusing the memory kept by the unet and the text
# encoders.
del sd.text_encoder
del sd.unet
del sd.sampler
peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
mx.eval(decoded[-1])
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
@ -69,3 +87,8 @@ if __name__ == "__main__":
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")

View File

@ -1,5 +1,4 @@
mlx>=0.1
safetensors
huggingface-hub
regex
numpy

View File

@ -1,7 +1,7 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import time
from typing import Tuple
from typing import Optional, Tuple
import mlx.core as mx
@ -13,7 +13,7 @@ from .model_io import (
load_tokenizer,
load_unet,
)
from .sampler import SimpleEulerSampler
from .sampler import SimpleEulerAncestralSampler, SimpleEulerSampler
class StableDiffusion:
@ -22,10 +22,27 @@ class StableDiffusion:
self.diffusion_config = load_diffusion_config(model)
self.unet = load_unet(model, float16)
self.text_encoder = load_text_encoder(model, float16)
self.autoencoder = load_autoencoder(model, float16)
self.autoencoder = load_autoencoder(model, False)
self.sampler = SimpleEulerSampler(self.diffusion_config)
self.tokenizer = load_tokenizer(model)
def ensure_models_are_loaded(self):
    mx.eval(self.unet.parameters())
    mx.eval(self.text_encoder.parameters())
    mx.eval(self.autoencoder.parameters())

def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
    # Tokenize the text
    tokens = [tokenizer.tokenize(text)]
    if negative_text is not None:
        tokens += [tokenizer.tokenize(negative_text)]
    lengths = [len(t) for t in tokens]
    N = max(lengths)
    tokens = [t + [0] * (N - len(t)) for t in tokens]
    tokens = mx.array(tokens)

    return tokens
def _get_text_conditioning(
self,
text: str,
@ -34,16 +51,12 @@ class StableDiffusion:
negative_text: str = "",
):
# Tokenize the text
tokens = [self.tokenizer.tokenize(text)]
if cfg_weight > 1:
tokens += [self.tokenizer.tokenize(negative_text)]
lengths = [len(t) for t in tokens]
N = max(lengths)
tokens = [t + [0] * (N - len(t)) for t in tokens]
tokens = mx.array(tokens)
tokens = self._tokenize(
self.tokenizer, text, (negative_text if cfg_weight > 1 else None)
)
# Compute the features
conditioning = self.text_encoder(tokens)
conditioning = self.text_encoder(tokens).last_hidden_state
# Repeat the conditioning for each of the generated images
if n_images > 1:
@ -51,10 +64,14 @@ class StableDiffusion:
return conditioning
def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
def _denoising_step(
self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5, text_time=None
):
x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
t_unet = mx.broadcast_to(t, [len(x_t_unet)])
eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
eps_pred = self.unet(
x_t_unet, t_unet, encoder_x=conditioning, text_time=text_time
)
if cfg_weight > 1:
eps_text, eps_neg = eps_pred.split(2)
@ -65,13 +82,21 @@ class StableDiffusion:
return x_t_prev
def _denoising_loop(
self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5
self,
x_T,
T,
conditioning,
num_steps: int = 50,
cfg_weight: float = 7.5,
text_time=None,
):
x_t = x_T
for t, t_prev in self.sampler.timesteps(
num_steps, start_time=T, dtype=self.dtype
):
x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
x_t = self._denoising_step(
x_t, t, t_prev, conditioning, cfg_weight, text_time
)
yield x_t
def generate_latents(
@ -142,3 +167,100 @@ class StableDiffusion:
x = self.autoencoder.decode(x_t)
x = mx.clip(x / 2 + 0.5, 0, 1)
return x
class StableDiffusionXL(StableDiffusion):
    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
        super().__init__(model, float16)

        self.sampler = SimpleEulerAncestralSampler(self.diffusion_config)

        self.text_encoder_1 = self.text_encoder
        self.tokenizer_1 = self.tokenizer
        del self.tokenizer, self.text_encoder

        self.text_encoder_2 = load_text_encoder(
            model,
            float16,
            model_key="text_encoder_2",
        )
        self.tokenizer_2 = load_tokenizer(
            model,
            merges_key="tokenizer_2_merges",
            vocab_key="tokenizer_2_vocab",
        )

    def ensure_models_are_loaded(self):
        mx.eval(self.unet.parameters())
        mx.eval(self.text_encoder_1.parameters())
        mx.eval(self.text_encoder_2.parameters())
        mx.eval(self.autoencoder.parameters())

    def _get_text_conditioning(
        self,
        text: str,
        n_images: int = 1,
        cfg_weight: float = 7.5,
        negative_text: str = "",
    ):
        tokens_1 = self._tokenize(
            self.tokenizer_1,
            text,
            (negative_text if cfg_weight > 1 else None),
        )
        tokens_2 = self._tokenize(
            self.tokenizer_2,
            text,
            (negative_text if cfg_weight > 1 else None),
        )

        conditioning_1 = self.text_encoder_1(tokens_1)
        conditioning_2 = self.text_encoder_2(tokens_2)
        conditioning = mx.concatenate(
            [conditioning_1.hidden_states[-2], conditioning_2.hidden_states[-2]],
            axis=-1,
        )
        pooled_conditioning = conditioning_2.pooled_output

        if n_images > 1:
            conditioning = mx.repeat(conditioning, n_images, axis=0)

        return conditioning, pooled_conditioning

    def generate_latents(
        self,
        text: str,
        n_images: int = 1,
        num_steps: int = 2,
        cfg_weight: float = 0.0,
        negative_text: str = "",
        latent_size: Tuple[int] = (64, 64),
        seed=None,
    ):
        # Set the PRNG state
        seed = seed or int(time.time())
        mx.random.seed(seed)

        # Get the text conditioning
        conditioning, pooled_conditioning = self._get_text_conditioning(
            text, n_images, cfg_weight, negative_text
        )
        text_time = (
            pooled_conditioning,
            mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
        )

        # Create the latent variables
        x_T = self.sampler.sample_prior(
            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
        )

        # Perform the denoising loop
        yield from self._denoising_loop(
            x_T,
            self.sampler.max_time,
            conditioning,
            num_steps,
            cfg_weight,
            text_time=text_time,
        )

View File

@ -1,15 +1,33 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import List, Optional
import mlx.core as mx
import mlx.nn as nn
from .config import CLIPTextModelConfig
_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
@dataclass
class CLIPOutput:
    # The last_hidden_state indexed at the EOS token and possibly projected if
    # the model has a projection layer
    pooled_output: Optional[mx.array] = None

    # The full sequence output of the transformer after the final layernorm
    last_hidden_state: Optional[mx.array] = None

    # A list of hidden states corresponding to the outputs of the transformer layers
    hidden_states: Optional[List[mx.array]] = None
class CLIPEncoderLayer(nn.Module):
"""The transformer encoder layer from CLIP."""
def __init__(self, model_dims: int, num_heads: int):
def __init__(self, model_dims: int, num_heads: int, activation: str):
super().__init__()
self.layer_norm1 = nn.LayerNorm(model_dims)
@ -25,6 +43,8 @@ class CLIPEncoderLayer(nn.Module):
self.linear1 = nn.Linear(model_dims, 4 * model_dims)
self.linear2 = nn.Linear(4 * model_dims, model_dims)
self.act = _ACTIVATIONS[activation]
def __call__(self, x, attn_mask=None):
y = self.layer_norm1(x)
y = self.attention(y, y, y, attn_mask)
@ -32,7 +52,7 @@ class CLIPEncoderLayer(nn.Module):
y = self.layer_norm2(x)
y = self.linear1(y)
y = nn.gelu_approx(y)
y = self.act(y)
y = self.linear2(y)
x = y + x
@ -48,23 +68,49 @@ class CLIPTextModel(nn.Module):
self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
self.layers = [
CLIPEncoderLayer(config.model_dims, config.num_heads)
CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
for i in range(config.num_layers)
]
self.final_layer_norm = nn.LayerNorm(config.model_dims)
if config.projection_dim is not None:
    self.text_projection = nn.Linear(
        config.model_dims, config.projection_dim, bias=False
    )

def _get_mask(self, N, dtype):
    indices = mx.arange(N)
    mask = indices[:, None] < indices[None]
    mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
    return mask
def __call__(self, x):
# Extract some shapes
B, N = x.shape
eos_tokens = x.argmax(-1)
# Compute the embeddings
x = self.token_embedding(x)
x = x + self.position_embedding.weight[:N]
# Compute the features from the transformer
mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
mask = self._get_mask(N, x.dtype)
hidden_states = []
for l in self.layers:
    x = l(x, mask)
    hidden_states.append(x)
# Apply the final layernorm and return
return self.final_layer_norm(x)
x = self.final_layer_norm(x)
last_hidden_state = x
# Select the EOS token
pooled_output = x[mx.arange(len(x)), eos_tokens]
if "text_projection" in self:
pooled_output = self.text_projection(pooled_output)
return CLIPOutput(
pooled_output=pooled_output,
last_hidden_state=last_hidden_state,
hidden_states=hidden_states,
)

View File

@ -1,4 +1,4 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
from dataclasses import dataclass
from typing import Optional, Tuple
@ -23,6 +23,8 @@ class CLIPTextModelConfig:
num_heads: int = 16
max_length: int = 77
vocab_size: int = 49408
projection_dim: Optional[int] = None
hidden_act: str = "quick_gelu"
@dataclass
@ -38,6 +40,21 @@ class UNetConfig:
num_attention_heads: Tuple[int] = (5, 10, 20, 20)
cross_attention_dim: Tuple[int] = (1024,) * 4
norm_num_groups: int = 32
down_block_types: Tuple[str] = (
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"CrossAttnDownBlock2D",
"DownBlock2D",
)
up_block_types: Tuple[str] = (
"UpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
"CrossAttnUpBlock2D",
)
addition_embed_type: Optional[str] = None
addition_time_embed_dim: Optional[int] = None
projection_class_embeddings_input_dim: Optional[int] = None
@dataclass

View File

@ -1,13 +1,12 @@
# Copyright © 2023 Apple Inc.
# Copyright © 2023-2024 Apple Inc.
import json
from functools import partial
from typing import Optional
import mlx.core as mx
import numpy as np
from huggingface_hub import hf_hub_download
from mlx.utils import tree_unflatten
from safetensors import safe_open as safetensor_open
from .clip import CLIPTextModel
from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
@ -17,6 +16,22 @@ from .vae import Autoencoder
_DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
_MODELS = {
# See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
"stabilityai/sdxl-turbo": {
"unet_config": "unet/config.json",
"unet": "unet/diffusion_pytorch_model.safetensors",
"text_encoder_config": "text_encoder/config.json",
"text_encoder": "text_encoder/model.safetensors",
"text_encoder_2_config": "text_encoder_2/config.json",
"text_encoder_2": "text_encoder_2/model.safetensors",
"vae_config": "vae/config.json",
"vae": "vae/diffusion_pytorch_model.safetensors",
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
"tokenizer_2_vocab": "tokenizer_2/vocab.json",
"tokenizer_2_merges": "tokenizer_2/merges.txt",
},
# See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
"stabilityai/stable-diffusion-2-1-base": {
"unet_config": "unet/config.json",
@ -28,14 +43,10 @@ _MODELS = {
"diffusion_config": "scheduler/scheduler_config.json",
"tokenizer_vocab": "tokenizer/vocab.json",
"tokenizer_merges": "tokenizer/merges.txt",
}
},
}
def _from_numpy(x):
return mx.array(np.ascontiguousarray(x))
def map_unet_weights(key, value):
# Map up/downsampling
if "downsamplers" in key:
@ -67,9 +78,9 @@ def map_unet_weights(key, value):
if "ff.net.0" in key:
k1 = key.replace("ff.net.0.proj", "linear1")
k2 = key.replace("ff.net.0.proj", "linear2")
v1, v2 = np.split(value, 2)
v1, v2 = mx.split(value, 2)
return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
return [(k1, v1), (k2, v2)]
if "conv_shortcut.weight" in key:
value = value.squeeze()
@ -80,8 +91,9 @@ def map_unet_weights(key, value):
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
return [(key, _from_numpy(value))]
return [(key, value)]
def map_clip_text_encoder_weights(key, value):
@ -109,7 +121,7 @@ def map_clip_text_encoder_weights(key, value):
if "mlp.fc2" in key:
key = key.replace("mlp.fc2", "linear2")
return [(key, _from_numpy(value))]
return [(key, value)]
def map_vae_weights(key, value):
@ -148,8 +160,9 @@ def map_vae_weights(key, value):
if len(value.shape) == 4:
value = value.transpose(0, 2, 3, 1)
value = value.reshape(-1).reshape(value.shape)
return [(key, _from_numpy(value))]
return [(key, value)]
def _flatten(params):
@ -157,9 +170,9 @@ def _flatten(params):
def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
dtype = np.float16 if float16 else np.float32
with safetensor_open(weight_file, framework="numpy") as f:
weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
dtype = mx.float16 if float16 else mx.float32
weights = mx.load(weight_file)
weights = _flatten([mapper(k, v.astype(dtype)) for k, v in weights.items()])
model.update(tree_unflatten(weights))
@ -186,6 +199,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
out_channels=config["out_channels"],
block_out_channels=config["block_out_channels"],
layers_per_block=[config["layers_per_block"]] * n_blocks,
transformer_layers_per_block=config.get(
"transformer_layers_per_block", (1,) * 4
),
num_attention_heads=(
[config["attention_head_dim"]] * n_blocks
if isinstance(config["attention_head_dim"], int)
@ -193,6 +209,13 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
),
cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
norm_num_groups=config["norm_num_groups"],
down_block_types=config["down_block_types"],
up_block_types=config["up_block_types"][::-1],
addition_embed_type=config.get("addition_embed_type", None),
addition_time_embed_dim=config.get("addition_time_embed_dim", None),
projection_class_embeddings_input_dim=config.get(
"projection_class_embeddings_input_dim", None
),
)
)
@ -204,15 +227,24 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
return model
def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
def load_text_encoder(
key: str = _DEFAULT_MODEL,
float16: bool = False,
model_key: str = "text_encoder",
config_key: Optional[str] = None,
):
"""Load the stable diffusion text encoder from Hugging Face Hub."""
_check_key(key, "load_text_encoder")
config_key = config_key or (model_key + "_config")
# Download the config and create the model
text_encoder_config = _MODELS[key]["text_encoder_config"]
text_encoder_config = _MODELS[key][config_key]
with open(hf_hub_download(key, text_encoder_config)) as f:
config = json.load(f)
with_projection = "WithProjection" in config["architectures"][0]
model = CLIPTextModel(
CLIPTextModelConfig(
num_layers=config["num_hidden_layers"],
@ -220,11 +252,13 @@ def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
num_heads=config["num_attention_heads"],
max_length=config["max_position_embeddings"],
vocab_size=config["vocab_size"],
projection_dim=config["projection_dim"] if with_projection else None,
hidden_act=config.get("hidden_act", "quick_gelu"),
)
)
# Download the weights and map them into the model
text_encoder_weights = _MODELS[key]["text_encoder"]
text_encoder_weights = _MODELS[key][model_key]
weight_file = hf_hub_download(key, text_encoder_weights)
_load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
@ -249,6 +283,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
block_out_channels=config["block_out_channels"],
layers_per_block=config["layers_per_block"],
norm_num_groups=config["norm_num_groups"],
scaling_factor=config.get("scaling_factor", 0.18215),
)
)
@ -276,14 +311,18 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
)
def load_tokenizer(key: str = _DEFAULT_MODEL):
def load_tokenizer(
key: str = _DEFAULT_MODEL,
vocab_key: str = "tokenizer_vocab",
merges_key: str = "tokenizer_merges",
):
_check_key(key, "load_tokenizer")
vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
vocab_file = hf_hub_download(key, _MODELS[key][vocab_key])
with open(vocab_file, encoding="utf-8") as f:
vocab = json.load(f)
merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
merges_file = hf_hub_download(key, _MODELS[key][merges_key])
with open(merges_file, encoding="utf-8") as f:
bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
bpe_merges = [tuple(m.split()) for m in bpe_merges]

View File

@ -83,3 +83,23 @@ class SimpleEulerSampler:
x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
return x_t_prev
class SimpleEulerAncestralSampler(SimpleEulerSampler):
    def step(self, eps_pred, x_t, t, t_prev):
        sigma = self.sigmas(t).astype(eps_pred.dtype)
        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)

        sigma2 = sigma.square()
        sigma_prev2 = sigma_prev.square()
        sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
        sigma_down = (sigma_prev2 - sigma_up**2).sqrt()

        dt = sigma_down - sigma
        x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt
        noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype)
        x_t_prev = x_t_prev + noise * sigma_up
        x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()

        return x_t_prev

View File

@ -74,7 +74,7 @@ class TransformerBlock(nn.Module):
y = self.norm3(x)
y_a = self.linear1(y)
y_b = self.linear2(y)
y = y_a * nn.gelu_approx(y_b) # approximate gelu?
y = y_a * nn.gelu(y_b)
y = self.linear3(y)
x = x + y
@ -106,10 +106,11 @@ class Transformer2D(nn.Module):
def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
# Save the input to add to the output
input_x = x
dtype = x.dtype
# Perform the input norm and projection
B, H, W, C = x.shape
x = self.norm(x).reshape(B, -1, C)
x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
x = self.proj_in(x)
# Apply the transformer
@ -150,15 +151,17 @@ class ResnetBlock2D(nn.Module):
self.conv_shortcut = nn.Linear(in_channels, out_channels)
def __call__(self, x, temb=None):
dtype = x.dtype
if temb is not None:
temb = self.time_emb_proj(nn.silu(temb))
y = self.norm1(x)
y = self.norm1(x.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv1(y)
if temb is not None:
y = y + temb[:, None, None, :]
y = self.norm2(y)
y = self.norm2(y.astype(mx.float32)).astype(dtype)
y = nn.silu(y)
y = self.conv2(y)
@ -292,6 +295,23 @@ class UNetModel(nn.Module):
config.block_out_channels[0] * 4,
)
if config.addition_embed_type == "text_time":
    self.add_time_proj = nn.SinusoidalPositionalEncoding(
        config.addition_time_embed_dim,
        max_freq=1,
        min_freq=math.exp(
            -math.log(10000)
            + 2 * math.log(10000) / config.addition_time_embed_dim
        ),
        scale=1.0,
        cos_first=True,
        full_turns=False,
    )
    self.add_embedding = TimestepEmbedding(
        config.projection_class_embeddings_input_dim,
        config.block_out_channels[0] * 4,
    )
# Make the downsampling blocks
block_channels = [config.block_out_channels[0]] + list(
config.block_out_channels
@ -308,7 +328,7 @@ class UNetModel(nn.Module):
resnet_groups=config.norm_num_groups,
add_downsample=(i < len(config.block_out_channels) - 1),
add_upsample=False,
add_cross_attention=(i < len(config.block_out_channels) - 1),
add_cross_attention="CrossAttn" in config.down_block_types[i],
)
for i, (in_channels, out_channels) in enumerate(
zip(block_channels, block_channels[1:])
@ -357,7 +377,7 @@ class UNetModel(nn.Module):
resnet_groups=config.norm_num_groups,
add_downsample=False,
add_upsample=(i > 0),
add_cross_attention=(i < len(config.block_out_channels) - 1),
add_cross_attention="CrossAttn" in config.up_block_types[i],
)
for i, (in_channels, out_channels, prev_out_channels) in reversed(
list(
@ -380,11 +400,27 @@ class UNetModel(nn.Module):
padding=(config.conv_out_kernel - 1) // 2,
)
def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
def __call__(
self,
x,
timestep,
encoder_x,
attn_mask=None,
encoder_attn_mask=None,
text_time=None,
):
# Compute the time embeddings
temb = self.timesteps(timestep).astype(x.dtype)
temb = self.time_embedding(temb)
# Add the extra text_time conditioning
if text_time is not None:
    text_emb, time_ids = text_time
    emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
    emb = mx.concatenate([text_emb, emb], axis=-1)
    emb = self.add_embedding(emb)
    temb = temb + emb
# Preprocess the input
x = self.conv_in(x)
@ -417,7 +453,8 @@ class UNetModel(nn.Module):
)
# Postprocess the output
x = self.conv_norm_out(x)
dtype = x.dtype
x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
x = nn.silu(x)
x = self.conv_out(x)

Binary file not shown (new image, 1.5 MiB).

View File

@ -4,26 +4,50 @@ import argparse
import mlx.core as mx
import numpy as np
from mlx.nn import QuantizedLinear
from PIL import Image
from tqdm import tqdm
from stable_diffusion import StableDiffusion
from stable_diffusion import StableDiffusion, StableDiffusionXL
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Generate images from a textual prompt using stable diffusion"
)
parser.add_argument("prompt")
parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
parser.add_argument("--n_images", type=int, default=4)
parser.add_argument("--steps", type=int, default=50)
parser.add_argument("--cfg", type=float, default=7.5)
parser.add_argument("--steps", type=int)
parser.add_argument("--cfg", type=float)
parser.add_argument("--negative_prompt", default="")
parser.add_argument("--n_rows", type=int, default=1)
parser.add_argument("--decoding_batch_size", type=int, default=1)
parser.add_argument("--no-float16", dest="float16", action="store_false")
parser.add_argument("--quantize", "-q", action="store_true")
parser.add_argument("--preload-models", action="store_true")
parser.add_argument("--output", default="out.png")
parser.add_argument("--verbose", "-v", action="store_true")
args = parser.parse_args()
sd = StableDiffusion()
if args.model == "sdxl":
sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
if args.quantize:
QuantizedLinear.quantize_module(sd.text_encoder_1)
QuantizedLinear.quantize_module(sd.text_encoder_2)
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 0.0
args.steps = args.steps or 2
else:
sd = StableDiffusion(
"stabilityai/stable-diffusion-2-1-base", float16=args.float16
)
if args.quantize:
QuantizedLinear.quantize_module(sd.text_encoder)
QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
args.cfg = args.cfg or 7.5
args.steps = args.steps or 50
if args.preload_models:
sd.ensure_models_are_loaded()
# Generate the latent vectors using diffusion
latents = sd.generate_latents(
@ -36,11 +60,24 @@ if __name__ == "__main__":
for x_t in tqdm(latents, total=args.steps):
mx.eval(x_t)
# The following is not necessary but it may help in memory
# constrained systems by reusing the memory kept by the unet and the text
# encoders.
if args.model == "sdxl":
del sd.text_encoder_1
del sd.text_encoder_2
else:
del sd.text_encoder
del sd.unet
del sd.sampler
peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
# Decode them into images
decoded = []
for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
mx.eval(decoded[-1])
peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
# Arrange them on a grid
x = mx.concatenate(decoded, axis=0)
@ -53,3 +90,8 @@ if __name__ == "__main__":
# Save them to disc
im = Image.fromarray(np.array(x))
im.save(args.output)
# Report the peak memory used during generation
if args.verbose:
print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")