diff --git a/flux/README.md b/flux/README.md
index b00a9621..c51d7da0 100644
--- a/flux/README.md
+++ b/flux/README.md
@@ -167,8 +167,9 @@
 python dreambooth.py \
   path/to/dreambooth/dataset/dog6
 ```
-
-Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.
+Or you can directly use the pre-processed Hugging Face dataset
+[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
+for fine-tuning.
 
 ```shell
 python dreambooth.py \
@@ -210,3 +211,71 @@
 speed during generation.
 [^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
 [^2]: The images are from unsplash by https://unsplash.com/@alvannee .
+
+
+Distributed Computation
+------------------------
+
+The FLUX example supports distributed computation during both generation and
+training. See the [distributed communication
+documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
+for information on how to set up MLX for distributed communication. The rest of
+this section assumes you can launch distributed MLX programs using `mlx.launch
+--hostfile hostfile.json` (an example hostfile is given at the end of this
+section).
+
+### Distributed Finetuning
+
+Distributed finetuning scales very well with FLUX; all you have to do is
+adjust the gradient accumulation steps and the number of iterations so that
+the effective batch size remains the same. For instance, to replicate the
+following training run
+
+```shell
+python dreambooth.py \
+  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
+  --lora-rank 4 --grad-accumulate 8 \
+  mlx-community/dreambooth-dog6
+```
+
+on 4 machines, we simply run
+
+```shell
+mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
+  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
+  --progress-every 150 --iterations 300 --learning-rate 0.0001 \
+  --lora-rank 4 --grad-accumulate 2 \
+  mlx-community/dreambooth-dog6
+```
+
+Note that the iterations changed from 1200 to 300 and the gradient
+accumulation steps from 8 to 2: with 4 nodes each accumulating gradients over
+2 batches, every update still averages over 8 batches, and 300 iterations on
+4 nodes process as many samples as 1200 iterations on one (a sketch of the
+gradient averaging is given at the end of this section).
+
+### Distributed Inference
+
+Distributed inference can follow two different approaches. The first is the
+data-parallel approach, where each node generates its own images and shares
+them at the end. The second is the model-parallel approach, where the model is
+sharded across the nodes and they collaboratively generate the images.
+
+The `txt2image.py` script attempts to choose the best approach depending on
+how many images are being generated across the nodes. The model-parallel
+approach can be forced by passing the argument `--force-shard`.
+
+For better performance in the model-parallel approach, we suggest using a
+[thunderbolt
+ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).
+
+Once again, all you have to do is use `mlx.launch` as follows:
+
+```shell
+mlx.launch --verbose --hostfile hostfile.json -- \
+  python txt2image.py --model schnell \
+  --n-images 8 \
+  --image-size 512x512 \
+  --verbose \
+  'A photo of an astronaut riding a horse on Mars'
+```
+
+For model-parallel generation, you may also want to pass `--env
+MLX_METAL_FAST_SYNCH=1` to `mlx.launch`; this is an experimental setting that
+reduces the CPU/GPU synchronization overhead.
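+
+For reference, the `hostfile.json` passed to `mlx.launch` is simply a JSON
+list of hosts. The hostnames and IP addresses below are placeholders; replace
+them with those of your own machines, and see the MLX distributed
+documentation for the exact fields expected by each backend:
+
+```json
+[
+    {"ssh": "host-1.local", "ips": ["192.168.0.1"]},
+    {"ssh": "host-2.local", "ips": ["192.168.0.2"]},
+    {"ssh": "host-3.local", "ips": ["192.168.0.3"]},
+    {"ssh": "host-4.local", "ips": ["192.168.0.4"]}
+]
+```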
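+
+Under the hood, the data-parallel finetuning above boils down to averaging
+the gradients across nodes before every optimizer update. Below is a minimal
+sketch of that pattern with MLX's distributed primitives, using a hypothetical
+`average_gradients` helper; it is an illustration, not the actual
+implementation in `dreambooth.py`:
+
+```python
+import mlx.core as mx
+from mlx.utils import tree_map
+
+world = mx.distributed.init()
+
+def average_gradients(grads):
+    # Illustrative helper: sum each gradient in the tree over all nodes,
+    # then divide by the node count. With a single process all_sum is a
+    # no-op, so the same code also runs unmodified on one machine.
+    return tree_map(lambda g: mx.distributed.all_sum(g) / world.size(), grads)
+```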
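+
+Similarly, the data-parallel inference approach amounts to each node sampling
+its own share of the images with a different seed and gathering everything at
+the end. The sketch below shows only the communication pattern; random arrays
+stand in for the actual FLUX sampling loop in `txt2image.py`:
+
+```python
+import mlx.core as mx
+
+world = mx.distributed.init()
+
+n_images = 8                         # total images requested across nodes
+per_node = n_images // world.size()  # this node's share
+
+# A different seed per node so every node samples different noise.
+mx.random.seed(42 + world.rank())
+
+# Stand-in for the real sampling loop.
+images = mx.random.uniform(shape=(per_node, 512, 512, 3))
+
+# all_gather concatenates along the first axis, so every node ends up
+# holding the full batch of n_images.
+all_images = mx.distributed.all_gather(images)
+print(world.rank(), all_images.shape)
+```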
diff --git a/flux/txt2image.py b/flux/txt2image.py
index fd59b711..cae0a6d9 100644
--- a/flux/txt2image.py
+++ b/flux/txt2image.py
@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Generate images from a textual prompt using stable diffusion"
+        description="Generate images from a textual prompt using FLUX"
     )
     parser.add_argument("prompt")
     parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")