Stable diffusion XL (#516)

Angelos Katharopoulos
2024-03-08 10:24:19 -08:00
committed by GitHub
parent 8c2cf665ed
commit 3a9e6c3f70
11 changed files with 449 additions and 105 deletions


@@ -2,22 +2,25 @@ Stable Diffusion
================
Stable Diffusion in MLX. The implementation was ported from Hugging Face's
[diffusers](https://huggingface.co/docs/diffusers/index) and model weights are
downloaded directly from the Hugging Face Hub. The implementation currently
supports the following models:

- [stabilityai/sdxl-turbo](https://huggingface.co/stabilityai/sdxl-turbo)
- [stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1)
![out](generated-mlx.png)
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign
saying MLX in capital letters.'*
Installation
------------
The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization.
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script.
You can install all of the above with the `requirements.txt` as follows:
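```shell
pip install -r requirements.txt
```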
@@ -43,7 +46,9 @@ sd = StableDiffusion()
#
# Because MLX is lazily evaluated, iterating over this generator doesn't
# actually perform the computation until mx.eval() is called.
latent_generator = sd.generate_latents("A photo of an astronaut riding a horse on Mars.")
latent_generator = sd.generate_latents(
"A photo of an astronaut riding a horse on Mars."
)
# Here we are evaluating each diffusion step, but we could also evaluate
# once at the end.
@@ -55,10 +60,16 @@ for x_t in latent_generator:
im = sd.decode(x_t)
```
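To get from the decoded latents to an image file, something along these lines
works (a minimal sketch; it assumes `decode` returns pixel values in [0, 1]
with shape `(batch, height, width, channels)`):

```python
import mlx.core as mx
import numpy as np
from PIL import Image

# Scale [0, 1] floats to 8-bit pixel values.
pixels = (im * 255).astype(mx.uint8)

# Convert the MLX array to numpy and save the first image in the batch.
Image.fromarray(np.array(pixels[0])).save("astronaut.png")
```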
The above is essentially the implementation of the `txt2image.py` script in the
root of the repository. You can use the script as follows:
```shell
python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
```
You can select the model using the `--model` argument. The currently supported
models are `sdxl` (default) and `sd`.
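For example, to generate with the Stable Diffusion 2.1 weights instead of the
default SDXL Turbo:

```shell
python txt2image.py --model sd "A photo of an astronaut riding a horse on Mars."
```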
Image 2 Image
-------------
@@ -71,48 +82,36 @@ to the forward diffusion process and the `strength` parameter. A `strength` of
random noise.
![image2image](im2im.png)
*Generations with varying strength using the original image and the prompt 'A lit fireplace'.*
The command to generate the above images is:
```shell
python image2image.py --strength 0.5 original.png 'A lit fireplace'
```
> [!Note]
> `image2image.py` will automatically downsample your input image to guarantee
> that its dimensions are divisible by 64. If you want full control of this
> process, resize your image prior to using the script.
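If you prefer to resize the image yourself, a minimal sketch with PIL (the
file names are placeholders; any target size works as long as both dimensions
are multiples of 64):

```python
from PIL import Image

img = Image.open("original.png")

# Round each dimension down to the nearest multiple of 64.
w, h = (s - s % 64 for s in img.size)
img.resize((w, h), Image.LANCZOS).save("resized.png")
```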
Memory constrained devices
--------------------------
The `txt2image.py` script loads the model in float16 by default, which
significantly reduces the memory required for image generation. However, since
the Stable Diffusion XL UNet alone has 2.6B parameters, quantization is
practically necessary to run it on devices with 8GB of RAM.
The `txt2image.py` script supports quantization via the `-q` or `--quantize`
command line flag. When quantization is enabled, the script quantizes the text
encoder models to 4 bits and the UNet to 8 bits, which allows generating images
on an 8GB Mac mini with no swapping.
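Under the hood this corresponds to something like the following sketch using
MLX's `nn.quantize` helper (the `text_encoder` and `unet` attribute names are
illustrative, not necessarily the pipeline's actual submodule names):

```python
import mlx.nn as nn

# Hypothetical handles to the pipeline's submodules; the real attribute
# names in the implementation may differ.
nn.quantize(sd.text_encoder, bits=4)  # text encoder -> 4 bits
nn.quantize(sd.unet, bits=8)          # UNet -> 8 bits
```

For example, the following generates four images with the quantized models: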
```shell
python txt2image.py --n_images 4 -q -v --output still-life.png "A painting of a vase on a wooden table, dark background, still life."
```
![painting](still-life.png)
*Image generated using Stable Diffusion XL Turbo in MLX with the above command on an 8GB M1 Mac mini.*