Add a README and requirements.txt
parent 76f5faba62
commit 7c8c5818f7

flux/README.md (new file, 180 lines)
@@ -0,0 +1,180 @@

FLUX
====

FLUX implementation in MLX. The implementation is ported directly from
[https://github.com/black-forest-labs/flux](https://github.com/black-forest-labs/flux)
and the model weights are downloaded directly from the Hugging Face Hub.

The goal of this example is to be clean and educational, and to allow for
experimentation with finetuning FLUX models as well as adding extra
functionality such as in-/outpainting, guidance with custom losses, etc.

![MLX image](static/generated-mlx.png)
*Image generated using FLUX-dev in MLX and the prompt 'An image in the style of
tron emanating futuristic technology with the word "MLX" in the center with
capital red letters.'*

Installation
------------

The dependencies are minimal, namely:

- `huggingface-hub` to download the checkpoints
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
- `sentencepiece` for the T5 tokenizer

You can install all of the above with the `requirements.txt` as follows:

    pip install -r requirements.txt

Inference
---------

Inference in this example is similar to the stable diffusion example. The
class to get you started is `FluxPipeline` from the `flux` module.

```python
import mlx.core as mx
from flux import FluxPipeline

# This will download all the weights from the HF Hub.
flux = FluxPipeline("flux-schnell")

# Make a generator that returns the latent variables from the reverse
# diffusion process.
latent_generator = flux.generate_latents(
    "A photo of an astronaut riding a horse on Mars",
    num_steps=4,
    latent_size=(32, 64),  # 256x512 image
)

# The first value from the generator contains the conditioning and the
# random noise at the beginning of the diffusion process.
conditioning = next(latent_generator)
(
    x_T,                # The initial noise
    x_positions,        # The integer positions used for image positional encoding
    t5_conditioning,    # The T5 features from the text prompt
    t5_positions,       # Integer positions for text (normally all 0s)
    clip_conditioning,  # The CLIP text features from the text prompt
) = conditioning

# Returning the conditioning as the first output from the generator allows us
# to unload T5 and CLIP before running the diffusion transformer.
mx.eval(conditioning)

# Evaluate each diffusion step.
for x_t in latent_generator:
    mx.eval(x_t)

# Note that we need to pass the latent size because it is collapsed and
# patchified in x_t and we need to unwrap it.
img = flux.decode(x_t, latent_size=(32, 64))
```
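
The decoded `img` can then be saved with `PIL`, continuing the example above.
This is a minimal sketch, assuming `decode` returns images with shape
`(B, H, W, C)` and values in `[0, 1]` (the `txt2image.py` script handles this
conversion for you):

```python
import numpy as np
from PIL import Image

# Assumption: `img` is (B, H, W, C) with values in [0, 1].
x = (img[0] * 255).astype(mx.uint8)
Image.fromarray(np.array(x)).save("astronaut.png")
```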

The above is essentially the implementation of the `txt2image.py` script,
except for some additional logic to quantize and/or load trained adapters. One
can use the script as follows:

```shell
python txt2image.py --n-images 4 --n-rows 2 --image-size 256x512 'A photo of an astronaut riding a horse on Mars.'
```

### Experimental Options

FLUX pads the prompt to a fixed size of 512 tokens for the dev model and 256
for the schnell model. Skipping the padding results in faster generation, but
it is not clear how it may affect the generated images. To enable that option
in this example, pass `--no-t5-padding` to the `txt2image.py` script or
instantiate the pipeline with `FluxPipeline("flux-schnell", t5_padding=False)`.
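
For instance, using the constructor call above, the inference example would
start with:

```python
from flux import FluxPipeline

# Skip T5 prompt padding (256 tokens for schnell); generation is faster but
# the effect on image quality is not well characterized.
flux = FluxPipeline("flux-schnell", t5_padding=False)
```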

Finetuning
----------

The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell,
but your mileage may vary) on a provided image dataset. The dataset folder
must have an `index.json` file with the following format:

```json
{
    "data": [
        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
        ...
    ]
}
```
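
Since every entry has the same shape, a small script can generate the file for
a flat folder of images that share one prompt. A sketch under that assumption;
`write_index` is illustrative and not part of the example:

```python
import json
from pathlib import Path

def write_index(dataset_dir, prompt):
    """Write an index.json listing every jpg/png in dataset_dir with one shared prompt."""
    dataset_dir = Path(dataset_dir)
    images = sorted(
        p.name
        for p in dataset_dir.iterdir()
        if p.suffix.lower() in {".jpg", ".jpeg", ".png"}
    )
    index = {"data": [{"image": im, "text": prompt} for im in images]}
    (dataset_dir / "index.json").write_text(json.dumps(index, indent=4))

write_index("path/to/dreambooth/dataset/dog6", "A photo of sks dog")
```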

The training script by default trains for 600 iterations with a batch size of
1, gradient accumulation of 4, and a LoRA rank of 8. Run `python dreambooth.py
--help` for the list of hyperparameters you can tune.
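
Spelled out with the flags used elsewhere in this README, the defaults
correspond to the following (a sketch; only flags that appear in this document
are shown, and the batch size of 1 has no flag here):

```shell
# Equivalent to running with no extra flags.
python dreambooth.py --iterations 600 --grad-accumulate 4 --lora-rank 8 \
    path/to/dreambooth/dataset/dog6
```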

> [!Note]
> FLUX finetuning requires approximately 50GB of RAM. QLoRA is coming soon and
> should reduce this number significantly.

### Training Example

This is a step-by-step finetuning example. We will be using the data from
[https://github.com/google/dreambooth](https://github.com/google/dreambooth).
In particular, we will use `dog6`, which is a popular example for showcasing
dreambooth [^1].

We start by making the following `index.json` file and placing it in the same
folder as the images.

```json
{
    "data": [
        {"image": "00.jpg", "text": "A photo of sks dog"},
        {"image": "01.jpg", "text": "A photo of sks dog"},
        {"image": "02.jpg", "text": "A photo of sks dog"},
        {"image": "03.jpg", "text": "A photo of sks dog"},
        {"image": "04.jpg", "text": "A photo of sks dog"}
    ]
}
```

Subsequently we finetune FLUX using the following command:

```shell
python dreambooth.py \
    --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
    --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
    --lora-rank 4 --grad-accumulate 8 \
    path/to/dreambooth/dataset/dog6
```

The training requires approximately 50GB of RAM and on an M2 Ultra it takes a
bit more than 1 hour.

### Using the Adapter

The adapters are saved in `mlx_output` and can be used directly by the
`txt2image.py` script. For instance,

```shell
python txt2image.py --model dev --save-raw --image-size 512x512 --n-images 1 \
    --adapter mlx_output/0001200_adapters.safetensors \
    --fuse-adapter \
    --no-t5-padding \
    'A photo of an sks dog lying on the sand at a beach in Greece'
```

generates an image that looks like the following,

![dog image](static/dog-r4-g8-1200.png)

and of course we can pass `--image-size 512x1024` to get larger images with
different aspect ratios,

![wide dog image](static/dog-r4-g8-1200-512x1024.png)

The arguments that are relevant to the adapters are of course `--adapter` and
`--fuse-adapter`. The first defines the path to an adapter to apply to the
model, and the second fuses the adapter back into the model to get a bit more
speed during generation.

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.

flux/requirements.txt (new file, 7 lines)
@@ -0,0 +1,7 @@
mlx>=0.18.1
huggingface-hub
regex
numpy
tqdm
Pillow
sentencepiece

flux/static/dog-r4-g8-1200-512x1024.png (new binary file, 754 KiB)
flux/static/dog-r4-g8-1200.png (new binary file, 423 KiB)
flux/static/generated-mlx.png (new binary file, 153 KiB)