Update README

This commit is contained in:
Angelos Katharopoulos 2025-03-22 17:44:20 -07:00
parent c109d9b596
commit c4d08de8b3
2 changed files with 72 additions and 3 deletions

View File

@@ -167,8 +167,9 @@ python dreambooth.py \
path/to/dreambooth/dataset/dog6
```
-Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
+Or you can directly use the pre-processed Hugging Face dataset
+[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
+for fine-tuning.
```shell
python dreambooth.py \
@@ -210,3 +211,71 @@ speed during generation.
[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from Unsplash by https://unsplash.com/@alvannee.
Distributed Computation
------------------------
The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.
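As a rough sketch, assuming the hostfile format from that documentation (a JSON
list giving each node's SSH hostname and IPs; the hostnames and addresses below
are placeholders), a 4-machine `hostfile.json` could be created like this:
```shell
cat > hostfile.json <<'EOF'
[
    {"ssh": "machine-1.local", "ips": ["192.168.1.1"]},
    {"ssh": "machine-2.local", "ips": ["192.168.1.2"]},
    {"ssh": "machine-3.local", "ips": ["192.168.1.3"]},
    {"ssh": "machine-4.local", "ips": ["192.168.1.4"]}
]
EOF
```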
### Distributed Finetuning
Distributed finetuning scales very well with FLUX. All you have to do is
adjust the gradient accumulation steps and the number of iterations so that the
effective batch size remains the same. For instance, to replicate the following
single-machine training run
```shell
python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 8 \
  mlx-community/dreambooth-dog6
```
On 4 machines we simply run
```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 150 --iterations 300 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 2 \
  mlx-community/dreambooth-dog6
```
Note that the iterations changed from 1200 to 300 and the gradient accumulation
steps from 8 to 2, since the work is now split across 4 machines.
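In general, dividing both values by the number of nodes keeps the effective
batch size per optimizer update and the total amount of data seen unchanged. A
quick sanity check, assuming `--iterations` counts per-node micro-batch
iterations as the numbers above suggest:
```shell
# 1 machine : 1200 iterations / 8 accumulation steps = 150 optimizer updates,
#             each averaging 8 micro-batches
# 4 machines:  300 iterations / 2 accumulation steps = 150 optimizer updates,
#             each averaging 4 nodes x 2 steps = 8 micro-batches
```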
### Distributed Inference
Distributed inference can follow two different approaches. The first is the
data-parallel approach, where each node generates its own images and shares
them at the end. The second is the model-parallel approach, where the model is
sharded across the nodes and they collaboratively generate the images.
The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.
For better performance in the model-parallel approach we suggest that you use a
[Thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
Once again, all you have to do is use `mlx.launch` as follows:
```shell
mlx.launch --verbose --hostfile hostfile.json -- \
  python txt2image.py --model schnell \
  --n-images 8 \
  --image-size 512x512 \
  --verbose \
  'A photo of an astronaut riding a horse on Mars'
```
For model-parallel generation you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`, an experimental setting that reduces
the CPU/GPU synchronization overhead.
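As a sketch, forcing the model-parallel path with that setting enabled could
look like the following (the image count and prompt are arbitrary):
```shell
mlx.launch --verbose --hostfile hostfile.json --env MLX_METAL_FAST_SYNCH=1 -- \
  python txt2image.py --model schnell \
  --n-images 4 \
  --image-size 512x512 \
  --force-shard \
  --verbose \
  'A photo of an astronaut riding a horse on Mars'
```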

View File

@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
-        description="Generate images from a textual prompt using stable diffusion"
+        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")