Stable diffusion XL (#516)

Mirror of https://github.com/ml-explore/mlx-examples.git
Parent commit: 8c2cf665ed
This commit: 3a9e6c3f70
@@ -2,22 +2,25 @@ Stable Diffusion
 ================
 
 Stable Diffusion in MLX. The implementation was ported from Hugging Face's
-[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
-and using the weights available on the Hugging Face Hub by Stability AI at
-[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).
+[diffusers](https://huggingface.co/docs/diffusers/index) and model weights are
+downloaded directly from the Hugging Face hub. The implementation currently
+supports the following models:
 
+- [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
+- [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
 
 
-*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*
+*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign
+saying MLX in capital letters.'*
 
 Installation
 ------------
 
 The dependencies are minimal, namely:
 
-- `safetensors` and `huggingface-hub` to load the checkpoints.
+- `huggingface-hub` to download the checkpoints.
 - `regex` for the tokenization
-- `numpy` because safetensors needs to return some form of array
-- `tqdm` and `PIL` for the `txt2image.py` script
+- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
 
 You can install all of the above with the `requirements.txt` as follows:
@@ -43,7 +46,9 @@ sd = StableDiffusion()
 #
 # Because MLX is lazily evaluated iterating over this generator doesn't
 # actually perform the computation until mx.eval() is called.
-latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")
+latent_generator = sd.generate_latents(
+    "A photo of an astronaut riding a horse on Mars."
+)
 
 # Here we are evaluating each diffusion step but we could also evaluate
 # once at the end.
@@ -55,10 +60,16 @@ for x_t in latent_generator:
     im = sd.decode(x_t)
 ```
 
-The above is almost line for line the implementation of the `txt2image.py`
-script in the root of the repository. You can use the script as follows:
+The above is essentially the implementation of the `txt2image.py` script in the
+root of the repository. You can use the script as follows:
 
-python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
+```shell
+python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
+```
+
+You can select the model using the `--model` argument. Currently supported models
+are `sdxl` (default) and `sd`.
 
 Image 2 Image
 -------------
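The `--model` flag mentioned above switches between the SD 2.1 and SDXL turbo pipelines; the same choice is available from Python. The following is a minimal sketch, not taken from the README, that drives the new `StableDiffusionXL` class directly using only names introduced in this diff (run it from the `stable_diffusion` example directory so the import resolves):

```python
import mlx.core as mx

from stable_diffusion import StableDiffusionXL

# sdxl-turbo is distilled for very few steps and no classifier-free guidance.
sdxl = StableDiffusionXL("stabilityai/sdxl-turbo", float16=True)

latents = sdxl.generate_latents(
    "A photo of an astronaut riding a horse on Mars.",
    n_images=1,
    num_steps=2,
    cfg_weight=0.0,
)
for x_t in latents:
    mx.eval(x_t)  # evaluate each denoising step produced lazily by the generator

image = sdxl.decode(x_t)  # image tensor with values in [0, 1]
mx.eval(image)
```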
@@ -71,48 +82,36 @@ to the forward diffusion process and the `strength` parameter. A `strength` of
 random noise.
 
 
 
 *Generations with varying strength using the original image and the prompt 'A lit fireplace'.*
 
 The command to generate the above images is:
 
-python image2image.py --strength 0.5 original.png 'A lit fireplace'
+```shell
+python image2image.py --strength 0.5 original.png 'A lit fireplace'
+```
 
-*Note: `image2image.py` will automatically downsample your input image to guarantee that its dimensions are divisible by 64. If you want full control of this process, resize your image prior to using the script.*
+> [!Note]
+> `image2image.py` will automatically downsample your input image to guarantee
+> that its dimensions are divisible by 64. If you want full control of this
+> process, resize your image prior to using the script.
 
-Performance
------------
+Memory constrained devices
+--------------------------
 
-The following table compares the performance of the UNet in stable diffusion.
-We report throughput in images per second **processed by the UNet** for the
-provided `txt2image.py` script and the `diffusers` library using the MPS
-PyTorch backend.
+The `txt2image.py` script by default loads the model in float16, which
+significantly reduces the memory required for image generation. However, since
+the Stable Diffusion XL UNet alone has 2.6B parameters, quantization is
+practically necessary to use it on devices with 8GB of RAM.
 
-At the time of writing this comparison convolutions are still some of the least
-optimized operations in MLX. Despite that, MLX still achieves **~40% higher
-throughput** than PyTorch with a batch size of 16 and ~15% higher when
-comparing the optimal batch sizes.
+The `txt2image.py` script supports quantization using the `-q` or `--quantize`
+command line arguments. When quantization is used, the script quantizes the
+text encoder models to 4 bits and the unet to 8 bits. This allows generating
+images on an 8GB Mac Mini with no swapping.
 
-Notably, PyTorch achieves almost ~50% higher throughput for the batch size of 1
-which is unfortunate as that means that a single image can be computed faster.
-However, when starting with the models not loaded in memory and PyTorch's MPS
-graph kernels not cached, the compilation time more than accounts for this
-speed difference.
-
-| Batch size | PyTorch     | MLX         |
-| ---------- | ----------- | ----------- |
-|          1 |   6.25 im/s |   4.17 im/s |
-|          2 |   7.14 im/s |   5.88 im/s |
-|          4 |**7.69 im/s**|   7.14 im/s |
-|          6 |   7.22 im/s |   8.00 im/s |
-|          8 |   6.89 im/s |   8.42 im/s |
-|         12 |   6.62 im/s |   8.51 im/s |
-|         16 |   6.32 im/s |**8.79 im/s**|
-
-The above experiments were made on an M2 Ultra with PyTorch version 2.1,
-diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
-used classifier free guidance which means that the above batch sizes result in
-double the images processed by the UNet.
-
-Note that the above table means that it takes about 90 seconds to fully
-generate 16 images with MLX and 50 diffusion steps with classifier free
-guidance and about 120 for PyTorch.
+```
+python txt2image.py --n_images 4 -q -v --output still-life.png "A painting of a vase on a wooden table, dark background, still life."
+```
+
+
+*Image generated using Stable Diffusion XL turbo in MLX with the above command on an 8GB M1 Mac mini*
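As a rough back-of-the-envelope check on the figures above (not stated in the README): 2.6e9 UNet parameters at 2 bytes each in float16 is about 5.2 GB for the UNet weights alone, while at 1 byte each after 8-bit quantization it is about 2.6 GB, which together with 4-bit text encoders leaves enough headroom to generate on an 8GB machine without swapping.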
@@ -22,10 +22,19 @@ if __name__ == "__main__":
     parser.add_argument("--negative_prompt", default="")
     parser.add_argument("--n_rows", type=int, default=1)
     parser.add_argument("--decoding_batch_size", type=int, default=1)
+    parser.add_argument("--quantize", "-q", action="store_true")
+    parser.add_argument("--no-float16", dest="float16", action="store_false")
+    parser.add_argument("--preload-models", action="store_true")
     parser.add_argument("--output", default="out.png")
+    parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
-    sd = StableDiffusion()
+    sd = StableDiffusion("stabilityai/stable-diffusion-2-1-base", float16=args.float16)
+    if args.quantize:
+        QuantizedLinear.quantize_module(sd.text_encoder)
+        QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+    if args.preload_models:
+        sd.ensure_models_are_loaded()
 
     # Read the image
     img = Image.open(args.image)
@@ -52,11 +61,20 @@ if __name__ == "__main__":
     for x_t in tqdm(latents, total=int(args.steps * args.strength)):
         mx.eval(x_t)
 
+    # The following is not necessary but it may help in memory
+    # constrained systems by reusing the memory kept by the unet and the text
+    # encoders.
+    del sd.text_encoder
+    del sd.unet
+    del sd.sampler
+    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
+
     # Decode them into images
     decoded = []
     for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
         decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
         mx.eval(decoded[-1])
+    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
 
     # Arrange them on a grid
     x = mx.concatenate(decoded, axis=0)
@@ -69,3 +87,8 @@ if __name__ == "__main__":
     # Save them to disc
     im = Image.fromarray(np.array(x))
     im.save(args.output)
+
+    # Report the peak memory used during generation
+    if args.verbose:
+        print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
+        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")
|
@ -1,5 +1,4 @@
|
|||||||
mlx>=0.1
|
mlx>=0.1
|
||||||
safetensors
|
|
||||||
huggingface-hub
|
huggingface-hub
|
||||||
regex
|
regex
|
||||||
numpy
|
numpy
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
# Copyright © 2023 Apple Inc.
|
# Copyright © 2023-2024 Apple Inc.
|
||||||
|
|
||||||
import time
|
import time
|
||||||
from typing import Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
import mlx.core as mx
|
import mlx.core as mx
|
||||||
|
|
||||||
@ -13,7 +13,7 @@ from .model_io import (
|
|||||||
load_tokenizer,
|
load_tokenizer,
|
||||||
load_unet,
|
load_unet,
|
||||||
)
|
)
|
||||||
from .sampler import SimpleEulerSampler
|
from .sampler import SimpleEulerAncestralSampler, SimpleEulerSampler
|
||||||
|
|
||||||
|
|
||||||
class StableDiffusion:
|
class StableDiffusion:
|
||||||
@ -22,10 +22,27 @@ class StableDiffusion:
|
|||||||
self.diffusion_config = load_diffusion_config(model)
|
self.diffusion_config = load_diffusion_config(model)
|
||||||
self.unet = load_unet(model, float16)
|
self.unet = load_unet(model, float16)
|
||||||
self.text_encoder = load_text_encoder(model, float16)
|
self.text_encoder = load_text_encoder(model, float16)
|
||||||
self.autoencoder = load_autoencoder(model, float16)
|
self.autoencoder = load_autoencoder(model, False)
|
||||||
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
self.sampler = SimpleEulerSampler(self.diffusion_config)
|
||||||
self.tokenizer = load_tokenizer(model)
|
self.tokenizer = load_tokenizer(model)
|
||||||
|
|
||||||
|
def ensure_models_are_loaded(self):
|
||||||
|
mx.eval(self.unet.parameters())
|
||||||
|
mx.eval(self.text_encoder.parameters())
|
||||||
|
mx.eval(self.autoencoder.parameters())
|
||||||
|
|
||||||
|
def _tokenize(self, tokenizer, text: str, negative_text: Optional[str] = None):
|
||||||
|
# Tokenize the text
|
||||||
|
tokens = [tokenizer.tokenize(text)]
|
||||||
|
if negative_text is not None:
|
||||||
|
tokens += [tokenizer.tokenize(negative_text)]
|
||||||
|
lengths = [len(t) for t in tokens]
|
||||||
|
N = max(lengths)
|
||||||
|
tokens = [t + [0] * (N - len(t)) for t in tokens]
|
||||||
|
tokens = mx.array(tokens)
|
||||||
|
|
||||||
|
return tokens
|
||||||
|
|
||||||
def _get_text_conditioning(
|
def _get_text_conditioning(
|
||||||
self,
|
self,
|
||||||
text: str,
|
text: str,
|
||||||
@@ -34,16 +51,12 @@ class StableDiffusion:
         negative_text: str = "",
     ):
         # Tokenize the text
-        tokens = [self.tokenizer.tokenize(text)]
-        if cfg_weight > 1:
-            tokens += [self.tokenizer.tokenize(negative_text)]
-        lengths = [len(t) for t in tokens]
-        N = max(lengths)
-        tokens = [t + [0] * (N - len(t)) for t in tokens]
-        tokens = mx.array(tokens)
+        tokens = self._tokenize(
+            self.tokenizer, text, (negative_text if cfg_weight > 1 else None)
+        )
 
         # Compute the features
-        conditioning = self.text_encoder(tokens)
+        conditioning = self.text_encoder(tokens).last_hidden_state
 
         # Repeat the conditioning for each of the generated images
         if n_images > 1:
@@ -51,10 +64,14 @@ class StableDiffusion:
 
         return conditioning
 
-    def _denoising_step(self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5):
+    def _denoising_step(
+        self, x_t, t, t_prev, conditioning, cfg_weight: float = 7.5, text_time=None
+    ):
         x_t_unet = mx.concatenate([x_t] * 2, axis=0) if cfg_weight > 1 else x_t
         t_unet = mx.broadcast_to(t, [len(x_t_unet)])
-        eps_pred = self.unet(x_t_unet, t_unet, encoder_x=conditioning)
+        eps_pred = self.unet(
+            x_t_unet, t_unet, encoder_x=conditioning, text_time=text_time
+        )
 
         if cfg_weight > 1:
             eps_text, eps_neg = eps_pred.split(2)
@@ -65,13 +82,21 @@ class StableDiffusion:
         return x_t_prev
 
     def _denoising_loop(
-        self, x_T, T, conditioning, num_steps: int = 50, cfg_weight: float = 7.5
+        self,
+        x_T,
+        T,
+        conditioning,
+        num_steps: int = 50,
+        cfg_weight: float = 7.5,
+        text_time=None,
     ):
         x_t = x_T
         for t, t_prev in self.sampler.timesteps(
             num_steps, start_time=T, dtype=self.dtype
         ):
-            x_t = self._denoising_step(x_t, t, t_prev, conditioning, cfg_weight)
+            x_t = self._denoising_step(
+                x_t, t, t_prev, conditioning, cfg_weight, text_time
+            )
             yield x_t
 
     def generate_latents(
@@ -142,3 +167,100 @@ class StableDiffusion:
         x = self.autoencoder.decode(x_t)
         x = mx.clip(x / 2 + 0.5, 0, 1)
         return x
+
+
+class StableDiffusionXL(StableDiffusion):
+    def __init__(self, model: str = _DEFAULT_MODEL, float16: bool = False):
+        super().__init__(model, float16)
+
+        self.sampler = SimpleEulerAncestralSampler(self.diffusion_config)
+
+        self.text_encoder_1 = self.text_encoder
+        self.tokenizer_1 = self.tokenizer
+        del self.tokenizer, self.text_encoder
+
+        self.text_encoder_2 = load_text_encoder(
+            model,
+            float16,
+            model_key="text_encoder_2",
+        )
+        self.tokenizer_2 = load_tokenizer(
+            model,
+            merges_key="tokenizer_2_merges",
+            vocab_key="tokenizer_2_vocab",
+        )
+
+    def ensure_models_are_loaded(self):
+        mx.eval(self.unet.parameters())
+        mx.eval(self.text_encoder_1.parameters())
+        mx.eval(self.text_encoder_2.parameters())
+        mx.eval(self.autoencoder.parameters())
+
+    def _get_text_conditioning(
+        self,
+        text: str,
+        n_images: int = 1,
+        cfg_weight: float = 7.5,
+        negative_text: str = "",
+    ):
+        tokens_1 = self._tokenize(
+            self.tokenizer_1,
+            text,
+            (negative_text if cfg_weight > 1 else None),
+        )
+        tokens_2 = self._tokenize(
+            self.tokenizer_2,
+            text,
+            (negative_text if cfg_weight > 1 else None),
+        )
+
+        conditioning_1 = self.text_encoder_1(tokens_1)
+        conditioning_2 = self.text_encoder_2(tokens_2)
+        conditioning = mx.concatenate(
+            [conditioning_1.hidden_states[-2], conditioning_2.hidden_states[-2]],
+            axis=-1,
+        )
+        pooled_conditioning = conditioning_2.pooled_output
+
+        if n_images > 1:
+            conditioning = mx.repeat(conditioning, n_images, axis=0)
+
+        return conditioning, pooled_conditioning
+
+    def generate_latents(
+        self,
+        text: str,
+        n_images: int = 1,
+        num_steps: int = 2,
+        cfg_weight: float = 0.0,
+        negative_text: str = "",
+        latent_size: Tuple[int] = (64, 64),
+        seed=None,
+    ):
+        # Set the PRNG state
+        seed = seed or int(time.time())
+        mx.random.seed(seed)
+
+        # Get the text conditioning
+        conditioning, pooled_conditioning = self._get_text_conditioning(
+            text, n_images, cfg_weight, negative_text
+        )
+        text_time = (
+            pooled_conditioning,
+            mx.array([[512, 512, 0, 0, 512, 512.0]] * len(pooled_conditioning)),
+        )
+
+        # Create the latent variables
+        x_T = self.sampler.sample_prior(
+            (n_images, *latent_size, self.autoencoder.latent_channels), dtype=self.dtype
+        )
+
+        # Perform the denoising loop
+        yield from self._denoising_loop(
+            x_T,
+            self.sampler.max_time,
+            conditioning,
+            num_steps,
+            cfg_weight,
+            text_time=text_time,
+        )
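For orientation, the concatenation in `_get_text_conditioning` above joins the penultimate hidden states of the two text encoders along the feature axis, while the pooled output of the second encoder is what later feeds the `text_time` conditioning. A rough shape sketch, assuming the usual SDXL encoder widths of 768 and 1280 (these values are assumptions here; in practice they come from the downloaded `text_encoder` configs):

```python
# Hypothetical shapes for a batch of B prompts padded to N tokens:
# conditioning_1.hidden_states[-2]  -> (B, N, 768)    penultimate layer, encoder 1
# conditioning_2.hidden_states[-2]  -> (B, N, 1280)   penultimate layer, encoder 2
# mx.concatenate([...], axis=-1)    -> (B, N, 2048)   cross-attention context for the UNet
# conditioning_2.pooled_output      -> (B, 1280)      pooled text embedding for text_time
```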
@@ -1,15 +1,33 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
+from dataclasses import dataclass
+from typing import List, Optional
+
 import mlx.core as mx
 import mlx.nn as nn
 
 from .config import CLIPTextModelConfig
 
+_ACTIVATIONS = {"quick_gelu": nn.gelu_fast_approx, "gelu": nn.gelu}
+
+
+@dataclass
+class CLIPOutput:
+    # The last_hidden_state indexed at the EOS token and possibly projected if
+    # the model has a projection layer
+    pooled_output: Optional[mx.array] = None
+
+    # The full sequence output of the transformer after the final layernorm
+    last_hidden_state: Optional[mx.array] = None
+
+    # A list of hidden states corresponding to the outputs of the transformer layers
+    hidden_states: Optional[List[mx.array]] = None
+
+
 class CLIPEncoderLayer(nn.Module):
     """The transformer encoder layer from CLIP."""
 
-    def __init__(self, model_dims: int, num_heads: int):
+    def __init__(self, model_dims: int, num_heads: int, activation: str):
         super().__init__()
 
         self.layer_norm1 = nn.LayerNorm(model_dims)
@@ -25,6 +43,8 @@ class CLIPEncoderLayer(nn.Module):
         self.linear1 = nn.Linear(model_dims, 4 * model_dims)
         self.linear2 = nn.Linear(4 * model_dims, model_dims)
 
+        self.act = _ACTIVATIONS[activation]
+
     def __call__(self, x, attn_mask=None):
         y = self.layer_norm1(x)
         y = self.attention(y, y, y, attn_mask)
@@ -32,7 +52,7 @@ class CLIPEncoderLayer(nn.Module):
 
         y = self.layer_norm2(x)
         y = self.linear1(y)
-        y = nn.gelu_approx(y)
+        y = self.act(y)
         y = self.linear2(y)
         x = y + x
 
@@ -48,23 +68,49 @@ class CLIPTextModel(nn.Module):
         self.token_embedding = nn.Embedding(config.vocab_size, config.model_dims)
         self.position_embedding = nn.Embedding(config.max_length, config.model_dims)
         self.layers = [
-            CLIPEncoderLayer(config.model_dims, config.num_heads)
+            CLIPEncoderLayer(config.model_dims, config.num_heads, config.hidden_act)
             for i in range(config.num_layers)
         ]
         self.final_layer_norm = nn.LayerNorm(config.model_dims)
 
+        if config.projection_dim is not None:
+            self.text_projection = nn.Linear(
+                config.model_dims, config.projection_dim, bias=False
+            )
+
+    def _get_mask(self, N, dtype):
+        indices = mx.arange(N)
+        mask = indices[:, None] < indices[None]
+        mask = mask.astype(dtype) * (-6e4 if dtype == mx.float16 else -1e9)
+        return mask
+
     def __call__(self, x):
         # Extract some shapes
         B, N = x.shape
+        eos_tokens = x.argmax(-1)
 
         # Compute the embeddings
         x = self.token_embedding(x)
         x = x + self.position_embedding.weight[:N]
 
         # Compute the features from the transformer
-        mask = nn.MultiHeadAttention.create_additive_causal_mask(N, x.dtype)
+        mask = self._get_mask(N, x.dtype)
+        hidden_states = []
         for l in self.layers:
             x = l(x, mask)
+            hidden_states.append(x)
 
         # Apply the final layernorm and return
-        return self.final_layer_norm(x)
+        x = self.final_layer_norm(x)
+        last_hidden_state = x
+
+        # Select the EOS token
+        pooled_output = x[mx.arange(len(x)), eos_tokens]
+        if "text_projection" in self:
+            pooled_output = self.text_projection(pooled_output)
+
+        return CLIPOutput(
+            pooled_output=pooled_output,
+            last_hidden_state=last_hidden_state,
+            hidden_states=hidden_states,
+        )
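A note on the pooling above: in the CLIP vocabulary the end-of-text token has the largest id, so taking `argmax(-1)` over the raw token ids recovers the position of the EOS token in each zero-padded sequence, and indexing the final hidden states at that position gives the pooled embedding that `text_projection` then maps to the shared space. A minimal sketch, assuming the usual BOS/EOS ids of 49406/49407:

```python
import mlx.core as mx

# Two zero-padded prompts; 49406 = <|startoftext|>, 49407 = <|endoftext|> (assumed ids).
tokens = mx.array(
    [
        [49406, 320, 1125, 49407, 0, 0],
        [49406, 320, 49407, 0, 0, 0],
    ]
)
eos_positions = tokens.argmax(-1)  # [3, 2]: index of the EOS token in each row
# pooled = hidden[mx.arange(len(tokens)), eos_positions] picks those rows of the
# final hidden states, exactly as in CLIPTextModel.__call__ above.
print(eos_positions)
```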
@@ -1,4 +1,4 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 from dataclasses import dataclass
 from typing import Optional, Tuple
@@ -23,6 +23,8 @@ class CLIPTextModelConfig:
     num_heads: int = 16
     max_length: int = 77
     vocab_size: int = 49408
+    projection_dim: Optional[int] = None
+    hidden_act: str = "quick_gelu"
 
 
 @dataclass
@@ -38,6 +40,21 @@ class UNetConfig:
     num_attention_heads: Tuple[int] = (5, 10, 20, 20)
     cross_attention_dim: Tuple[int] = (1024,) * 4
     norm_num_groups: int = 32
+    down_block_types: Tuple[str] = (
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "CrossAttnDownBlock2D",
+        "DownBlock2D",
+    )
+    up_block_types: Tuple[str] = (
+        "UpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+        "CrossAttnUpBlock2D",
+    )
+    addition_embed_type: Optional[str] = None
+    addition_time_embed_dim: Optional[int] = None
+    projection_class_embeddings_input_dim: Optional[int] = None
 
 
 @dataclass
@@ -1,13 +1,12 @@
-# Copyright © 2023 Apple Inc.
+# Copyright © 2023-2024 Apple Inc.
 
 import json
 from functools import partial
+from typing import Optional
 
 import mlx.core as mx
-import numpy as np
 from huggingface_hub import hf_hub_download
 from mlx.utils import tree_unflatten
-from safetensors import safe_open as safetensor_open
 
 from .clip import CLIPTextModel
 from .config import AutoencoderConfig, CLIPTextModelConfig, DiffusionConfig, UNetConfig
@@ -17,6 +16,22 @@ from .vae import Autoencoder
 
 _DEFAULT_MODEL = "stabilityai/stable-diffusion-2-1-base"
 _MODELS = {
+    # See https://huggingface.co/stabilityai/sdxl-turbo for the model details and license
+    "stabilityai/sdxl-turbo": {
+        "unet_config": "unet/config.json",
+        "unet": "unet/diffusion_pytorch_model.safetensors",
+        "text_encoder_config": "text_encoder/config.json",
+        "text_encoder": "text_encoder/model.safetensors",
+        "text_encoder_2_config": "text_encoder_2/config.json",
+        "text_encoder_2": "text_encoder_2/model.safetensors",
+        "vae_config": "vae/config.json",
+        "vae": "vae/diffusion_pytorch_model.safetensors",
+        "diffusion_config": "scheduler/scheduler_config.json",
+        "tokenizer_vocab": "tokenizer/vocab.json",
+        "tokenizer_merges": "tokenizer/merges.txt",
+        "tokenizer_2_vocab": "tokenizer_2/vocab.json",
+        "tokenizer_2_merges": "tokenizer_2/merges.txt",
+    },
     # See https://huggingface.co/stabilityai/stable-diffusion-2-1-base for the model details and license
     "stabilityai/stable-diffusion-2-1-base": {
         "unet_config": "unet/config.json",
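For context, each entry in `_MODELS` maps a logical component to a file path inside the corresponding Hugging Face repository; the loaders fetch those files with `hf_hub_download`, the same pattern `load_unet` uses further down. A minimal sketch of that pattern (the printed key is an expected value, shown only for illustration):

```python
import json

from huggingface_hub import hf_hub_download

# Download (or reuse from the local HF cache) the UNet config of sdxl-turbo.
config_path = hf_hub_download("stabilityai/sdxl-turbo", "unet/config.json")
with open(config_path) as f:
    config = json.load(f)
print(config["cross_attention_dim"])  # expected to be 2048 for the SDXL UNet
```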
@@ -28,14 +43,10 @@ _MODELS = {
         "diffusion_config": "scheduler/scheduler_config.json",
         "tokenizer_vocab": "tokenizer/vocab.json",
         "tokenizer_merges": "tokenizer/merges.txt",
-    }
+    },
 }
 
 
-def _from_numpy(x):
-    return mx.array(np.ascontiguousarray(x))
-
-
 def map_unet_weights(key, value):
     # Map up/downsampling
     if "downsamplers" in key:
@@ -67,9 +78,9 @@ def map_unet_weights(key, value):
     if "ff.net.0" in key:
         k1 = key.replace("ff.net.0.proj", "linear1")
         k2 = key.replace("ff.net.0.proj", "linear2")
-        v1, v2 = np.split(value, 2)
+        v1, v2 = mx.split(value, 2)
 
-        return [(k1, _from_numpy(v1)), (k2, _from_numpy(v2))]
+        return [(k1, v1), (k2, v2)]
 
     if "conv_shortcut.weight" in key:
         value = value.squeeze()
@@ -80,8 +91,9 @@ def map_unet_weights(key, value):
 
     if len(value.shape) == 4:
         value = value.transpose(0, 2, 3, 1)
+        value = value.reshape(-1).reshape(value.shape)
 
-    return [(key, _from_numpy(value))]
+    return [(key, value)]
 
 
 def map_clip_text_encoder_weights(key, value):
@@ -109,7 +121,7 @@ def map_clip_text_encoder_weights(key, value):
     if "mlp.fc2" in key:
         key = key.replace("mlp.fc2", "linear2")
 
-    return [(key, _from_numpy(value))]
+    return [(key, value)]
 
 
 def map_vae_weights(key, value):
@@ -148,8 +160,9 @@ def map_vae_weights(key, value):
 
     if len(value.shape) == 4:
         value = value.transpose(0, 2, 3, 1)
+        value = value.reshape(-1).reshape(value.shape)
 
-    return [(key, _from_numpy(value))]
+    return [(key, value)]
 
 
 def _flatten(params):
@@ -157,9 +170,9 @@ def _flatten(params):
 
 
 def _load_safetensor_weights(mapper, model, weight_file, float16: bool = False):
-    dtype = np.float16 if float16 else np.float32
-    with safetensor_open(weight_file, framework="numpy") as f:
-        weights = _flatten([mapper(k, f.get_tensor(k).astype(dtype)) for k in f.keys()])
+    dtype = mx.float16 if float16 else mx.float32
+    weights = mx.load(weight_file)
+    weights = _flatten([mapper(k, v.astype(dtype)) for k, v in weights.items()])
     model.update(tree_unflatten(weights))
 
 
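An aside on the change above: `mx.load` can read a `.safetensors` file directly into a dictionary of MLX arrays, so the `safetensors`/`numpy` round trip is no longer needed. A minimal sketch of the same pattern, with a placeholder file name:

```python
import mlx.core as mx

# "weights.safetensors" is a placeholder path used only for illustration.
weights = mx.load("weights.safetensors")  # dict: parameter name -> mx.array
weights = {k: v.astype(mx.float16) for k, v in weights.items()}  # optional cast
```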
@@ -186,6 +199,9 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             out_channels=config["out_channels"],
             block_out_channels=config["block_out_channels"],
             layers_per_block=[config["layers_per_block"]] * n_blocks,
+            transformer_layers_per_block=config.get(
+                "transformer_layers_per_block", (1,) * 4
+            ),
             num_attention_heads=(
                 [config["attention_head_dim"]] * n_blocks
                 if isinstance(config["attention_head_dim"], int)
@@ -193,6 +209,13 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
             ),
             cross_attention_dim=[config["cross_attention_dim"]] * n_blocks,
             norm_num_groups=config["norm_num_groups"],
+            down_block_types=config["down_block_types"],
+            up_block_types=config["up_block_types"][::-1],
+            addition_embed_type=config.get("addition_embed_type", None),
+            addition_time_embed_dim=config.get("addition_time_embed_dim", None),
+            projection_class_embeddings_input_dim=config.get(
+                "projection_class_embeddings_input_dim", None
+            ),
         )
     )
 
@@ -204,15 +227,24 @@ def load_unet(key: str = _DEFAULT_MODEL, float16: bool = False):
     return model
 
 
-def load_text_encoder(key: str = _DEFAULT_MODEL, float16: bool = False):
+def load_text_encoder(
+    key: str = _DEFAULT_MODEL,
+    float16: bool = False,
+    model_key: str = "text_encoder",
+    config_key: Optional[str] = None,
+):
     """Load the stable diffusion text encoder from Hugging Face Hub."""
     _check_key(key, "load_text_encoder")
 
+    config_key = config_key or (model_key + "_config")
+
     # Download the config and create the model
-    text_encoder_config = _MODELS[key]["text_encoder_config"]
+    text_encoder_config = _MODELS[key][config_key]
     with open(hf_hub_download(key, text_encoder_config)) as f:
         config = json.load(f)
 
+    with_projection = "WithProjection" in config["architectures"][0]
+
     model = CLIPTextModel(
         CLIPTextModelConfig(
             num_layers=config["num_hidden_layers"],
@@ -220,11 +252,13 @@ def load_text_encoder(
             num_heads=config["num_attention_heads"],
             max_length=config["max_position_embeddings"],
             vocab_size=config["vocab_size"],
+            projection_dim=config["projection_dim"] if with_projection else None,
+            hidden_act=config.get("hidden_act", "quick_gelu"),
        )
     )
 
     # Download the weights and map them into the model
-    text_encoder_weights = _MODELS[key]["text_encoder"]
+    text_encoder_weights = _MODELS[key][model_key]
     weight_file = hf_hub_download(key, text_encoder_weights)
     _load_safetensor_weights(map_clip_text_encoder_weights, model, weight_file, float16)
 
@@ -249,6 +283,7 @@ def load_autoencoder(key: str = _DEFAULT_MODEL, float16: bool = False):
             block_out_channels=config["block_out_channels"],
             layers_per_block=config["layers_per_block"],
             norm_num_groups=config["norm_num_groups"],
+            scaling_factor=config.get("scaling_factor", 0.18215),
         )
     )
 
@@ -276,14 +311,18 @@ def load_diffusion_config(key: str = _DEFAULT_MODEL):
     )
 
 
-def load_tokenizer(key: str = _DEFAULT_MODEL):
+def load_tokenizer(
+    key: str = _DEFAULT_MODEL,
+    vocab_key: str = "tokenizer_vocab",
+    merges_key: str = "tokenizer_merges",
+):
     _check_key(key, "load_tokenizer")
 
-    vocab_file = hf_hub_download(key, _MODELS[key]["tokenizer_vocab"])
+    vocab_file = hf_hub_download(key, _MODELS[key][vocab_key])
     with open(vocab_file, encoding="utf-8") as f:
         vocab = json.load(f)
 
-    merges_file = hf_hub_download(key, _MODELS[key]["tokenizer_merges"])
+    merges_file = hf_hub_download(key, _MODELS[key][merges_key])
     with open(merges_file, encoding="utf-8") as f:
         bpe_merges = f.read().strip().split("\n")[1 : 49152 - 256 - 2 + 1]
         bpe_merges = [tuple(m.split()) for m in bpe_merges]
@@ -83,3 +83,23 @@ class SimpleEulerSampler:
         x_t_prev = x_t_prev * (sigma_prev.square() + 1).rsqrt()
 
         return x_t_prev
+
+
+class SimpleEulerAncestralSampler(SimpleEulerSampler):
+    def step(self, eps_pred, x_t, t, t_prev):
+        sigma = self.sigmas(t).astype(eps_pred.dtype)
+        sigma_prev = self.sigmas(t_prev).astype(eps_pred.dtype)
+
+        sigma2 = sigma.square()
+        sigma_prev2 = sigma_prev.square()
+        sigma_up = (sigma_prev2 * (sigma2 - sigma_prev2) / sigma2).sqrt()
+        sigma_down = (sigma_prev2 - sigma_up**2).sqrt()
+
+        dt = sigma_down - sigma
+        x_t_prev = (sigma2 + 1).sqrt() * x_t + eps_pred * dt
+        noise = mx.random.normal(x_t_prev.shape).astype(x_t_prev.dtype)
+        x_t_prev = x_t_prev + noise * sigma_up
+
+        x_t_prev = x_t_prev * (sigma_prev2 + 1).rsqrt()
+
+        return x_t_prev
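For reference, the `step` above is the Euler ancestral update written on the rescaled latents this sampler keeps; with the current and next noise levels sigma_t and sigma_{t-1}:

```latex
\sigma_{\text{up}}^2 = \sigma_{t-1}^2\,\frac{\sigma_t^2 - \sigma_{t-1}^2}{\sigma_t^2},
\qquad
\sigma_{\text{down}} = \sqrt{\sigma_{t-1}^2 - \sigma_{\text{up}}^2},
\qquad
x_{t-1} = \frac{\sqrt{\sigma_t^2 + 1}\,x_t
          + \epsilon_\theta\,(\sigma_{\text{down}} - \sigma_t)
          + \sigma_{\text{up}}\,z}{\sqrt{\sigma_{t-1}^2 + 1}},
\quad z \sim \mathcal{N}(0, I),
```

where the sqrt(sigma^2 + 1) factors convert between the unit-variance latents used internally and the noisy sample the Euler step acts on.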
@@ -74,7 +74,7 @@ class TransformerBlock(nn.Module):
         y = self.norm3(x)
         y_a = self.linear1(y)
         y_b = self.linear2(y)
-        y = y_a * nn.gelu_approx(y_b)  # approximate gelu?
+        y = y_a * nn.gelu(y_b)
         y = self.linear3(y)
         x = x + y
 
@@ -106,10 +106,11 @@ class Transformer2D(nn.Module):
     def __call__(self, x, encoder_x, attn_mask, encoder_attn_mask):
         # Save the input to add to the output
         input_x = x
+        dtype = x.dtype
 
         # Perform the input norm and projection
         B, H, W, C = x.shape
-        x = self.norm(x).reshape(B, -1, C)
+        x = self.norm(x.astype(mx.float32)).astype(dtype).reshape(B, -1, C)
         x = self.proj_in(x)
 
         # Apply the transformer
@@ -150,15 +151,17 @@ class ResnetBlock2D(nn.Module):
             self.conv_shortcut = nn.Linear(in_channels, out_channels)
 
     def __call__(self, x, temb=None):
+        dtype = x.dtype
+
         if temb is not None:
             temb = self.time_emb_proj(nn.silu(temb))
 
-        y = self.norm1(x)
+        y = self.norm1(x.astype(mx.float32)).astype(dtype)
         y = nn.silu(y)
         y = self.conv1(y)
         if temb is not None:
             y = y + temb[:, None, None, :]
-        y = self.norm2(y)
+        y = self.norm2(y.astype(mx.float32)).astype(dtype)
         y = nn.silu(y)
         y = self.conv2(y)
 
@@ -292,6 +295,23 @@ class UNetModel(nn.Module):
             config.block_out_channels[0] * 4,
         )
 
+        if config.addition_embed_type == "text_time":
+            self.add_time_proj = nn.SinusoidalPositionalEncoding(
+                config.addition_time_embed_dim,
+                max_freq=1,
+                min_freq=math.exp(
+                    -math.log(10000)
+                    + 2 * math.log(10000) / config.addition_time_embed_dim
+                ),
+                scale=1.0,
+                cos_first=True,
+                full_turns=False,
+            )
+            self.add_embedding = TimestepEmbedding(
+                config.projection_class_embeddings_input_dim,
+                config.block_out_channels[0] * 4,
+            )
+
         # Make the downsampling blocks
         block_channels = [config.block_out_channels[0]] + list(
             config.block_out_channels
@@ -308,7 +328,7 @@ class UNetModel(nn.Module):
                 resnet_groups=config.norm_num_groups,
                 add_downsample=(i < len(config.block_out_channels) - 1),
                 add_upsample=False,
-                add_cross_attention=(i < len(config.block_out_channels) - 1),
+                add_cross_attention="CrossAttn" in config.down_block_types[i],
             )
             for i, (in_channels, out_channels) in enumerate(
                 zip(block_channels, block_channels[1:])
@@ -357,7 +377,7 @@ class UNetModel(nn.Module):
                 resnet_groups=config.norm_num_groups,
                 add_downsample=False,
                 add_upsample=(i > 0),
-                add_cross_attention=(i < len(config.block_out_channels) - 1),
+                add_cross_attention="CrossAttn" in config.up_block_types[i],
             )
             for i, (in_channels, out_channels, prev_out_channels) in reversed(
                 list(
@@ -380,11 +400,27 @@ class UNetModel(nn.Module):
             padding=(config.conv_out_kernel - 1) // 2,
         )
 
-    def __call__(self, x, timestep, encoder_x, attn_mask=None, encoder_attn_mask=None):
+    def __call__(
+        self,
+        x,
+        timestep,
+        encoder_x,
+        attn_mask=None,
+        encoder_attn_mask=None,
+        text_time=None,
+    ):
         # Compute the time embeddings
         temb = self.timesteps(timestep).astype(x.dtype)
         temb = self.time_embedding(temb)
 
+        # Add the extra text_time conditioning
+        if text_time is not None:
+            text_emb, time_ids = text_time
+            emb = self.add_time_proj(time_ids).flatten(1).astype(x.dtype)
+            emb = mx.concatenate([text_emb, emb], axis=-1)
+            emb = self.add_embedding(emb)
+            temb = temb + emb
+
         # Preprocess the input
         x = self.conv_in(x)
 
@@ -417,7 +453,8 @@ class UNetModel(nn.Module):
         )
 
         # Postprocess the output
-        x = self.conv_norm_out(x)
+        dtype = x.dtype
+        x = self.conv_norm_out(x.astype(mx.float32)).astype(dtype)
         x = nn.silu(x)
         x = self.conv_out(x)
 
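An aside on the `text_time` path added above: SDXL conditions the UNet not only on the pooled text embedding but also on six micro-conditioning scalars (original image size, crop top-left corner, target size); in this port the generator always passes `[512, 512, 0, 0, 512, 512]`, i.e. a 512x512 image with no cropping. Each scalar is embedded with the sinusoidal encoding, flattened, concatenated to the pooled embedding, and pushed through `add_embedding` before being added to the timestep embedding. A rough shape sketch, assuming the usual SDXL values of `addition_time_embed_dim = 256` and a 1280-dimensional pooled embedding (both are assumptions; at run time they come from the downloaded configs):

```python
# time_ids                       -> (B, 6)
# add_time_proj(time_ids)        -> (B, 6, 256)
# .flatten(1)                    -> (B, 1536)
# concat with pooled text_emb    -> (B, 1280 + 1536) = (B, 2816)
# add_embedding(...)             -> (B, 4 * block_out_channels[0]), added to temb
```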
BIN  stable_diffusion/still-life.png  (new binary file, not shown; 1.5 MiB)
@@ -4,26 +4,50 @@ import argparse
 
 import mlx.core as mx
 import numpy as np
+from mlx.nn import QuantizedLinear
 from PIL import Image
 from tqdm import tqdm
 
-from stable_diffusion import StableDiffusion
+from stable_diffusion import StableDiffusion, StableDiffusionXL
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
         description="Generate images from a textual prompt using stable diffusion"
     )
     parser.add_argument("prompt")
+    parser.add_argument("--model", choices=["sd", "sdxl"], default="sdxl")
     parser.add_argument("--n_images", type=int, default=4)
-    parser.add_argument("--steps", type=int, default=50)
-    parser.add_argument("--cfg", type=float, default=7.5)
+    parser.add_argument("--steps", type=int)
+    parser.add_argument("--cfg", type=float)
     parser.add_argument("--negative_prompt", default="")
     parser.add_argument("--n_rows", type=int, default=1)
     parser.add_argument("--decoding_batch_size", type=int, default=1)
+    parser.add_argument("--no-float16", dest="float16", action="store_false")
+    parser.add_argument("--quantize", "-q", action="store_true")
+    parser.add_argument("--preload-models", action="store_true")
     parser.add_argument("--output", default="out.png")
+    parser.add_argument("--verbose", "-v", action="store_true")
     args = parser.parse_args()
 
-    sd = StableDiffusion()
+    if args.model == "sdxl":
+        sd = StableDiffusionXL("stabilityai/sdxl-turbo", float16=args.float16)
+        if args.quantize:
+            QuantizedLinear.quantize_module(sd.text_encoder_1)
+            QuantizedLinear.quantize_module(sd.text_encoder_2)
+            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+        args.cfg = args.cfg or 0.0
+        args.steps = args.steps or 2
+    else:
+        sd = StableDiffusion(
+            "stabilityai/stable-diffusion-2-1-base", float16=args.float16
+        )
+        if args.quantize:
+            QuantizedLinear.quantize_module(sd.text_encoder)
+            QuantizedLinear.quantize_module(sd.unet, group_size=32, bits=8)
+        args.cfg = args.cfg or 7.5
+        args.steps = args.steps or 50
+    if args.preload_models:
+        sd.ensure_models_are_loaded()
 
     # Generate the latent vectors using diffusion
     latents = sd.generate_latents(
@@ -36,11 +60,24 @@ if __name__ == "__main__":
     for x_t in tqdm(latents, total=args.steps):
         mx.eval(x_t)
 
+    # The following is not necessary but it may help in memory
+    # constrained systems by reusing the memory kept by the unet and the text
+    # encoders.
+    if args.model == "sdxl":
+        del sd.text_encoder_1
+        del sd.text_encoder_2
+    else:
+        del sd.text_encoder
+    del sd.unet
+    del sd.sampler
+    peak_mem_unet = mx.metal.get_peak_memory() / 1024**3
+
     # Decode them into images
     decoded = []
     for i in tqdm(range(0, args.n_images, args.decoding_batch_size)):
         decoded.append(sd.decode(x_t[i : i + args.decoding_batch_size]))
         mx.eval(decoded[-1])
+    peak_mem_overall = mx.metal.get_peak_memory() / 1024**3
 
     # Arrange them on a grid
     x = mx.concatenate(decoded, axis=0)
@@ -53,3 +90,8 @@ if __name__ == "__main__":
     # Save them to disc
     im = Image.fromarray(np.array(x))
     im.save(args.output)
+
+    # Report the peak memory used during generation
+    if args.verbose:
+        print(f"Peak memory used for the unet: {peak_mem_unet:.3f}GB")
+        print(f"Peak memory used overall: {peak_mem_overall:.3f}GB")