```shell
python dreambooth.py \
  path/to/dreambooth/dataset/dog6
```

Or you can directly use the pre-processed Hugging Face dataset
[mlx-community/dreambooth-dog6](https://huggingface.co/datasets/mlx-community/dreambooth-dog6)
for fine-tuning.

```shell
python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 8 \
  mlx-community/dreambooth-dog6
```

[^1]: Refer to the [arXiv paper](https://arxiv.org/abs/2208.12242) for more details.

[^2]: The images are from Unsplash by [@alvannee](https://unsplash.com/@alvannee).

Distributed Computation
------------------------

The FLUX example supports distributed computation during both generation and
training. See the [distributed communication
documentation](https://ml-explore.github.io/mlx/build/html/usage/distributed.html)
for information on how to set up MLX for distributed communication. The rest of
this section assumes you can launch distributed MLX programs using `mlx.launch
--hostfile hostfile.json`.
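
For concreteness, the hostfile is a small JSON list describing the machines
that take part in the computation. The sketch below is a hypothetical 4-host
setup: the hostnames and IPs are placeholders, and the `ssh`/`ips` fields
follow the ring example in the linked documentation, so double-check the exact
schema against your MLX version.

```shell
# Write a placeholder hostfile for four machines reachable over ssh.
cat > hostfile.json <<'EOF'
[
  {"ssh": "machine-0.local", "ips": ["192.168.1.10"]},
  {"ssh": "machine-1.local", "ips": ["192.168.1.11"]},
  {"ssh": "machine-2.local", "ips": ["192.168.1.12"]},
  {"ssh": "machine-3.local", "ips": ["192.168.1.13"]}
]
EOF
```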

### Distributed Finetuning

Distributed finetuning scales very well with FLUX; all you have to do is
adjust the gradient accumulation and the number of iterations so that the
effective batch size remains the same. For instance, to replicate the
following training

```shell
python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 600 --iterations 1200 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 8 \
  mlx-community/dreambooth-dog6
```

on 4 machines, we simply run

```shell
mlx.launch --verbose --hostfile hostfile.json -- python dreambooth.py \
  --progress-prompt 'A photo of an sks dog lying on the sand at a beach in Greece' \
  --progress-every 150 --iterations 300 --learning-rate 0.0001 \
  --lora-rank 4 --grad-accumulate 2 \
  mlx-community/dreambooth-dog6
```

Note that the iterations changed from 1200 to 300 and the gradient
accumulation from 8 to 2: with 4 machines now contributing to every update,
both are divided by the number of nodes so that the effective batch size stays
the same.
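
As a quick sanity check of that arithmetic (assuming `--iterations` counts
per-node gradient steps, which is what the numbers above suggest):

```shell
# Effective batch per optimizer update = grad-accumulate x nodes
# (times the unchanged per-node batch size):
echo $(( 8 * 1 ))     # single machine: 8
echo $(( 2 * 4 ))     # four machines:  8

# Optimizer updates = iterations / grad-accumulate:
echo $(( 1200 / 8 ))  # single machine: 150
echo $(( 300 / 2 ))   # four machines:  150
```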

### Distributed Inference

Distributed inference can be divided into two different approaches. The first
is the data-parallel approach, where each node generates its own images and
shares them at the end. The second is the model-parallel approach, where the
model is sharded across the nodes and they collaboratively generate the
images.

The `txt2image.py` script will attempt to choose the best approach depending on
how many images are being generated across the nodes. The model-parallel
approach can be forced by passing the argument `--force-shard`.

For better performance in the model-parallel approach, we suggest that you use
a [Thunderbolt
ring](https://ml-explore.github.io/mlx/build/html/usage/distributed.html#getting-started-with-ring).

Once again, all you have to do is use `mlx.launch` as follows:

```shell
mlx.launch --verbose --hostfile hostfile.json -- \
  python txt2image.py --model schnell \
  --n-images 8 \
  --image-size 512x512 \
  --verbose \
  'A photo of an astronaut riding a horse on Mars'
```

For model-parallel generation, you may also want to pass `--env
MLX_METAL_FAST_SYNCH=1` to `mlx.launch`, which is an experimental setting that
reduces the CPU/GPU synchronization overhead.

```python
# txt2image.py
import argparse

if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate images from a textual prompt using FLUX"
    )
    parser.add_argument("prompt")
    parser.add_argument("--model", choices=["schnell", "dev"], default="schnell")
```
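
As a usage sketch of this parser (the prompt is positional, and `--model`
defaults to `schnell`):

```shell
python txt2image.py --model dev 'A photo of an astronaut riding a horse on Mars'
```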