FLUX: The dataset is adjusted to train.jsonl

This commit is contained in:
madroid 2024-10-12 19:37:34 +08:00
parent b0e017a16c
commit 1252536b4b
2 changed files with 21 additions and 30 deletions

View File

@@ -94,17 +94,12 @@ Finetuning
 The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
 but ymmv) on a provided image dataset. The dataset folder must have an
-`index.json` file with the following format:
+`train.jsonl` file with the following format:
-```json
-{
-    "data": [
-        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
-        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
-        {"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
-        ...
-    ]
-}
+```jsonl
+{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
+{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
+...
 ```
 The training script by default trains for 600 iterations with a batch size of
@@ -126,19 +121,15 @@ The training images are the following 5 images [^2]:
 ![dog6](static/dog6.png)
-We start by making the following `index.json` file and placing it in the same
+We start by making the following `train.jsonl` file and placing it in the same
 folder as the images.
-```json
-{
-    "data": [
-        {"image": "00.jpg", "text": "A photo of sks dog"},
-        {"image": "01.jpg", "text": "A photo of sks dog"},
-        {"image": "02.jpg", "text": "A photo of sks dog"},
-        {"image": "03.jpg", "text": "A photo of sks dog"},
-        {"image": "04.jpg", "text": "A photo of sks dog"}
-    ]
-}
+```jsonl
+{"image": "00.jpg", "prompt": "A photo of sks dog"}
+{"image": "01.jpg", "prompt": "A photo of sks dog"}
+{"image": "02.jpg", "prompt": "A photo of sks dog"}
+{"image": "03.jpg", "prompt": "A photo of sks dog"}
+{"image": "04.jpg", "prompt": "A photo of sks dog"}
 ```
 Subsequently we finetune FLUX using the following command:

View File

@@ -23,11 +23,11 @@ class FinetuningDataset:
         self.args = args
         self.flux = flux
         self.dataset_base = Path(args.dataset)
-        dataset_index = self.dataset_base / "index.json"
-        if not dataset_index.exists():
-            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
-        with open(dataset_index, "r") as f:
-            self.index = json.load(f)
+        data_file = self.dataset_base / "train.jsonl"
+        if not data_file.exists():
+            raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.")
+        with open(data_file, "r") as fid:
+            self.data = [json.loads(l) for l in fid]
         self.latents = []
         self.t5_features = []
@@ -78,7 +78,7 @@ class FinetuningDataset:
     def encode_images(self):
        """Encode the images in the latent space to prepare for training."""
        self.flux.ae.eval()
-        for sample in tqdm(self.index["data"]):
+        for sample in tqdm(self.data, desc="encode images"):
            input_img = Image.open(self.dataset_base / sample["image"])
            for i in range(self.args.num_augmentations):
                img = self._random_crop_resize(input_img)
@@ -91,8 +91,8 @@ class FinetuningDataset:
     def encode_prompts(self):
        """Pre-encode the prompts so that we don't recompute them during
        training (doesn't allow finetuning the text encoders)."""
-        for sample in tqdm(self.index["data"]):
-            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
+        for sample in tqdm(self.data, desc="encode prompts"):
+            t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]])
            t5_feat = self.flux.t5(t5_tok)
            clip_feat = self.flux.clip(clip_tok).pooled_output
            mx.eval(t5_feat, clip_feat)
@@ -272,7 +272,7 @@ if __name__ == "__main__":
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
     )
-    print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
+    print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
    # Set up the optimizer and training steps. The steps are a bit verbose to
    # support gradient accumulation together with compilation.