FLUX: The dataset is adjusted to train.jsonl

This commit is contained in:
madroid 2024-10-12 19:37:34 +08:00
parent b0e017a16c
commit 1252536b4b
2 changed files with 21 additions and 30 deletions

View File

@ -94,17 +94,12 @@ Finetuning
The `dreambooth.py` script supports LoRA finetuning of FLUX-dev (and schnell
but ymmv) on a provided image dataset. The dataset folder must have an
`index.json` file with the following format:
`train.jsonl` file with the following format:
```json
{
"data": [
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
{"image": "path-to-image-relative-to-dataset", "text": "Prompt to use with this image"},
```jsonl
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
{"image": "path-to-image-relative-to-dataset", "prompt": "Prompt to use with this image"}
...
]
}
```
The training script by default trains for 600 iterations with a batch size of
@ -126,19 +121,15 @@ The training images are the following 5 images [^2]:
![dog6](static/dog6.png)
We start by making the following `index.json` file and placing it in the same
We start by making the following `train.jsonl` file and placing it in the same
folder as the images.
```json
{
"data": [
{"image": "00.jpg", "text": "A photo of sks dog"},
{"image": "01.jpg", "text": "A photo of sks dog"},
{"image": "02.jpg", "text": "A photo of sks dog"},
{"image": "03.jpg", "text": "A photo of sks dog"},
{"image": "04.jpg", "text": "A photo of sks dog"}
]
}
```jsonl
{"image": "00.jpg", "prompt": "A photo of sks dog"}
{"image": "01.jpg", "prompt": "A photo of sks dog"}
{"image": "02.jpg", "prompt": "A photo of sks dog"}
{"image": "03.jpg", "prompt": "A photo of sks dog"}
{"image": "04.jpg", "prompt": "A photo of sks dog"}
```
Subsequently we finetune FLUX using the following command:

View File

@ -23,11 +23,11 @@ class FinetuningDataset:
self.args = args
self.flux = flux
self.dataset_base = Path(args.dataset)
dataset_index = self.dataset_base / "index.json"
if not dataset_index.exists():
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
with open(dataset_index, "r") as f:
self.index = json.load(f)
data_file = self.dataset_base / "train.jsonl"
if not data_file.exists():
raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.")
with open(data_file, "r") as fid:
self.data = [json.loads(l) for l in fid]
self.latents = []
self.t5_features = []
@ -78,7 +78,7 @@ class FinetuningDataset:
def encode_images(self):
"""Encode the images in the latent space to prepare for training."""
self.flux.ae.eval()
for sample in tqdm(self.index["data"]):
for sample in tqdm(self.data, desc="encode images"):
input_img = Image.open(self.dataset_base / sample["image"])
for i in range(self.args.num_augmentations):
img = self._random_crop_resize(input_img)
@ -91,8 +91,8 @@ class FinetuningDataset:
def encode_prompts(self):
"""Pre-encode the prompts so that we don't recompute them during
training (doesn't allow finetuning the text encoders)."""
for sample in tqdm(self.index["data"]):
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
for sample in tqdm(self.data, desc="encode prompts"):
t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)