Stable Diffusion
================

Stable Diffusion in MLX. The implementation was ported from Hugging Face's
[diffusers](https://huggingface.co/docs/diffusers/index) and we are fetching
and using the weights available on the Hugging Face Hub by Stability AI at
[stabilityai/stable-diffusion-2-1](https://huggingface.co/stabilityai/stable-diffusion-2-1).

|
||
|
*Image generated using Stable Diffusion in MLX and the prompt 'A big red sign saying MLX in capital letters.'*
|
||
|
|
||
|

Installation
------------

The dependencies are minimal, namely:

- `safetensors` and `huggingface-hub` to load the checkpoints.
- `regex` for the tokenization
- `numpy` because safetensors needs to return some form of array
- `tqdm` and `PIL` for the `txt2image.py` script

You can install all of the above with the `requirements.txt` as follows:

    pip install -r requirements.txt

Usage
------

Although each component in this repository can be used by itself, the fastest
way to get started is by using the `StableDiffusion` class from the `diffusion`
module.

```python
import mlx.core as mx

from stable_diffusion import StableDiffusion

# This will download all the weights from the HF Hub and load the models in
# memory.
sd = StableDiffusion()

# This creates a Python generator that returns the latent produced by the
# reverse diffusion process.
#
# Because MLX is lazily evaluated, iterating over this generator doesn't
# actually perform the computation until mx.eval() is called.
latent_generator = sd.generate_latents(
    "A photo of an astronaut riding a horse on Mars."
)

# Here we are evaluating each diffusion step but we could also evaluate
# once at the end.
for x_t in latent_generator:
    mx.simplify(x_t)  # remove possible redundant computation, e.g. reuse
                      # scalars etc.
    mx.eval(x_t)

# Now x_t is the last latent from the reverse process, aka x_0. We can
# decode it into an image using the stable diffusion VAE.
im = sd.decode(x_t)
```
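
To save the result you still have to convert the decoded output to 8-bit pixels. A minimal sketch of that post-processing, assuming `sd.decode` returns a `(1, H, W, 3)` array with values in `[0, 1]` (both are assumptions about its output, not guarantees of the API):

```python
import numpy as np
from PIL import Image

# Assumption: `im` from the snippet above has shape (1, H, W, 3)
# with float values in [0, 1].
pixels = np.array(im[0])                  # MLX array -> numpy array
pixels = (pixels * 255).astype(np.uint8)  # scale to 8-bit pixels
Image.fromarray(pixels).save("out.png")
```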

The above is almost line for line the implementation of the `txt2image.py`
script in the root of the repository. You can use the script as follows:

    python txt2image.py "A photo of an astronaut riding a horse on Mars." --n_images 4 --n_rows 2
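
For reference, a flag pair like `--n_images 4 --n_rows 2` maps naturally onto a simple grid-assembly helper; the sketch below is illustrative and not the script's actual code:

```python
import numpy as np
from PIL import Image

def make_grid(images, n_rows):
    # `images` is a list of equally sized (H, W, 3) uint8 arrays;
    # stack them row by row into a single grid image.
    n_cols = len(images) // n_rows
    rows = [
        np.concatenate(images[r * n_cols:(r + 1) * n_cols], axis=1)
        for r in range(n_rows)
    ]
    return Image.fromarray(np.concatenate(rows, axis=0))
```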

Performance
-----------

The following table compares the performance of the UNet in stable diffusion.
We report throughput in images per second for the provided `txt2image.py`
script and the `diffusers` library using the MPS PyTorch backend.
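
An images-per-second figure of this kind comes from a plain timing loop; the sketch below is illustrative (`generate_batch` is a hypothetical callable standing in for one full generation, not part of this repository):

```python
import time

def throughput(generate_batch, batch_size, n_iters=5):
    generate_batch(batch_size)  # warm-up: exclude loading/compilation
    start = time.time()
    for _ in range(n_iters):
        generate_batch(batch_size)
    elapsed = time.time() - start
    return n_iters * batch_size / elapsed  # images per second
```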

At the time of writing this comparison, convolutions are still some of the
least optimized operations in MLX. Despite that, MLX still achieves **~40%
higher throughput** than PyTorch with a batch size of 16 and ~15% higher when
comparing the optimal batch sizes.

Notably, PyTorch achieves almost 50% higher throughput at a batch size of 1,
which is unfortunate, as it means a single image is generated faster. However,
when starting with the models not loaded in memory and PyTorch's MPS graph
kernels not cached, the compilation time more than accounts for this speed
difference.

| Batch size |    PyTorch    |      MLX      |
| ---------- | ------------- | ------------- |
| 1          |     6.25 im/s |     4.17 im/s |
| 2          |     7.14 im/s |     5.88 im/s |
| 4          | **7.69 im/s** |     7.14 im/s |
| 6          |     7.22 im/s |     8.00 im/s |
| 8          |     6.89 im/s |     8.42 im/s |
| 12         |     6.62 im/s |     8.51 im/s |
| 16         |     6.32 im/s | **8.79 im/s** |

The above experiments were made on an M2 Ultra with PyTorch version 2.1,
diffusers version 0.21.4 and transformers version 4.33.3. For the generation we
used classifier-free guidance, which means that the above batch sizes result in
double the number of images processed by the UNet.
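
As a sketch of why the batch doubles: classifier-free guidance evaluates the UNet on an unconditional and a conditional copy of every latent and blends the two noise predictions. The helper below is illustrative only; the names and signatures are not the repository's API:

```python
import mlx.core as mx

def cfg_noise(unet, x_t, t, cond_emb, uncond_emb, guidance_scale=7.5):
    # One UNet call on a doubled batch: [unconditional; conditional],
    # so a user-facing batch of B costs 2 * B latents per step.
    x_pair = mx.concatenate([x_t, x_t], axis=0)
    emb_pair = mx.concatenate([uncond_emb, cond_emb], axis=0)
    eps_pair = unet(x_pair, t, emb_pair)
    eps_uncond, eps_cond = mx.split(eps_pair, 2, axis=0)
    # Push the prediction away from the unconditional direction.
    return eps_uncond + guidance_scale * (eps_cond - eps_uncond)
```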