mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-06-25 18:11:17 +08:00
Update README
This commit is contained in:
parent
c109d9b596
commit
c4d08de8b3
@@ -167,8 +167,9 @@ python dreambooth.py \

```
path/to/dreambooth/dataset/dog6
```

Or you can directly use the pre-processed Hugging Face dataset [mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6) for fine-tuning.

```shell
python dreambooth.py \
```
@@ -210,3 +211,71 @@ speed during generation.

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.
[^2]: The images are from Unsplash by https://unsplash.com/@alvannee .

Distributed Computation
------------------------

The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.
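For reference, the hostfile is a JSON description of the participating machines. A minimal sketch is shown below; the hostnames and IP addresses are placeholders, and the exact schema is documented in the MLX distributed communication guide linked above.

```json
[
    {"ssh": "machine-1.local", "ips": ["192.168.0.1"]},
    {"ssh": "machine-2.local", "ips": ["192.168.0.2"]}
]
```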
### Distributed Finetuning

Distributed finetuning scales very well with FLUX. All you have to do is
adjust the gradient accumulation steps and the number of iterations so that
the effective batch size remains the same. For instance, to replicate the
following training run

```shell
python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 8 \
  mlx-community/dreambooth-dog6
```

on 4 machines, we simply run

```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 150 --iterations 300 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 2 \
  mlx-community/dreambooth-dog6
```

Note that the iterations changed from 1200 to 300 and the gradient
accumulation steps from 8 to 2.
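The rescaling above keeps both the number of optimizer updates and the effective per-update batch size constant. A quick sketch of the arithmetic (an illustrative helper, not part of `dreambooth.py`, and assuming gradients are averaged across nodes so the effective batch scales with the node count):

```python
def scaled_schedule(iterations, grad_accumulate, num_nodes):
    """Rescale a single-node training schedule for num_nodes workers.

    Both the total optimizer updates (iterations / grad_accumulate) and the
    effective batch per update (grad_accumulate * num_nodes) stay constant.
    """
    assert grad_accumulate % num_nodes == 0, "accumulation must divide evenly"
    return iterations // num_nodes, grad_accumulate // num_nodes

# Single-node recipe from above: 1200 iterations with gradient accumulation 8.
iters, accum = scaled_schedule(1200, 8, num_nodes=4)
print(iters, accum)  # 300 2
```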
### Distributed Inference

Distributed inference can be divided into two different approaches. The first
is the data-parallel approach, where each node generates its own images and
shares them at the end. The second is the model-parallel approach, where the
model is sharded across the nodes and they collaboratively generate the
images.
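As an illustration of the data-parallel split (a sketch, not the actual logic in `txt2image.py`), dividing the requested image count across the nodes might look like:

```python
def images_for_rank(n_images, rank, size):
    """How many of n_images node `rank` out of `size` nodes generates.

    Hypothetical helper for illustration: spread the images as evenly as
    possible, giving the remainder to the lowest-ranked nodes.
    """
    base, rem = divmod(n_images, size)
    return base + (1 if rank < rem else 0)

counts = [images_for_rank(8, r, 4) for r in range(4)]
print(counts)  # [2, 2, 2, 2] -- each of 4 nodes renders 2 of the 8 images
```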
The `txt2image.py` script will attempt to choose the best approach depending
on how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.

For better performance in the model-parallel approach, we suggest that you use
a [thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).

Once again, all you have to do is use `mlx.launch` as follows:

```shell
mlx.launch --verbose --hostfile hostfile.json -- \
  python txt2image.py --model schnell \
  --n-images 8 \
  --image-size 512x512 \
  --verbose \
  'A photo of an astronaut riding a horse on Mars'
```

For model-parallel generation you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`; this is an experimental setting that
reduces the CPU/GPU synchronization overhead.
```diff
@@ -41,7 +41,7 @@ def load_adapter(flux, adapter_file, fuse=False):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser(
-        description="Generate images from a textual prompt using stable diffusion"
+        description="Generate images from a textual prompt using FLUX"
     )
     parser.add_argument("prompt")
     parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
```