FLUX: The dataset is adjusted to train.jsonl

madroid
2024-10-12 19:37:34 +08:00
parent b0e017a16c
commit 1252536b4b
2 changed files with 21 additions and 30 deletions


@@ -23,11 +23,11 @@ class FinetuningDataset:
         self.args = args
         self.flux = flux
         self.dataset_base = Path(args.dataset)
-        dataset_index = self.dataset_base / "index.json"
-        if not dataset_index.exists():
-            raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
-        with open(dataset_index, "r") as f:
-            self.index = json.load(f)
+        data_file = self.dataset_base / "train.jsonl"
+        if not data_file.exists():
+            raise ValueError(f"The fine-tuning dataset 'train.jsonl' was not found in the '{args.dataset}' path.")
+        with open(data_file, "r") as fid:
+            self.data = [json.loads(l) for l in fid]
 
         self.latents = []
         self.t5_features = []
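
After this change the loader expects a train.jsonl file in the dataset folder: one JSON object per line, carrying the "image" and "prompt" keys that the encoding methods below read. A minimal sketch of writing such a file, with a hypothetical folder name and image files:

import json
from pathlib import Path

# Hypothetical dataset folder and image names; the "image" and "prompt" keys
# are the ones the loader above reads from each line of train.jsonl.
dataset = Path("my-dataset")
dataset.mkdir(exist_ok=True)
samples = [
    {"image": "dog_01.png", "prompt": "a photo of a dog on the beach"},
    {"image": "dog_02.png", "prompt": "a photo of a dog in the snow"},
]
with open(dataset / "train.jsonl", "w") as fid:
    for sample in samples:
        fid.write(json.dumps(sample) + "\n")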
@@ -78,7 +78,7 @@ class FinetuningDataset:
     def encode_images(self):
         """Encode the images in the latent space to prepare for training."""
         self.flux.ae.eval()
-        for sample in tqdm(self.index["data"]):
+        for sample in tqdm(self.data, desc="encode images"):
             input_img = Image.open(self.dataset_base / sample["image"])
             for i in range(self.args.num_augmentations):
                 img = self._random_crop_resize(input_img)
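
The image-encoding loop now iterates the parsed train.jsonl entries directly and opens each image relative to the dataset folder, encoding it num_augmentations times. A rough sketch of that loop shape, using a hypothetical stand-in for _random_crop_resize (the real method is defined elsewhere in this file and is not part of the diff):

from pathlib import Path
from PIL import Image

def random_crop_resize(img, size=512):
    # Stand-in for FinetuningDataset._random_crop_resize (not shown in this
    # diff): a plain center crop to a square followed by a resize.
    side = min(img.size)
    left = (img.width - side) // 2
    top = (img.height - side) // 2
    return img.crop((left, top, left + side, top + side)).resize((size, size))

dataset_base = Path("my-dataset")       # hypothetical dataset folder
sample = {"image": "dog_01.png"}        # one parsed line of train.jsonl
input_img = Image.open(dataset_base / sample["image"])
crops = []
for _ in range(4):                      # plays the role of args.num_augmentations
    crops.append(random_crop_resize(input_img))
    # the real loop encodes each crop with the FLUX autoencoder and caches the latent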
@@ -91,8 +91,8 @@ class FinetuningDataset:
     def encode_prompts(self):
         """Pre-encode the prompts so that we don't recompute them during
         training (doesn't allow finetuning the text encoders)."""
-        for sample in tqdm(self.index["data"]):
-            t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
+        for sample in tqdm(self.data, desc="encode prompts"):
+            t5_tok, clip_tok = self.flux.tokenize([sample["prompt"]])
             t5_feat = self.flux.t5(t5_tok)
             clip_feat = self.flux.clip(clip_tok).pooled_output
             mx.eval(t5_feat, clip_feat)
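
For context, the prompt features are computed once here and cached, so the T5 and CLIP encoders never run inside the training loop. A small sketch of that step, assuming a flux object with the same tokenize / t5 / clip interface used above:

import mlx.core as mx

def encode_prompt(flux, prompt):
    # Tokenize and encode a single prompt with both text encoders, as in the
    # loop above; mx.eval forces the lazy arrays to concrete values so the
    # cached features are fully materialized before training starts.
    t5_tok, clip_tok = flux.tokenize([prompt])
    t5_feat = flux.t5(t5_tok)
    clip_feat = flux.clip(clip_tok).pooled_output
    mx.eval(t5_feat, clip_feat)
    return t5_feat, clip_feat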
@@ -272,7 +272,7 @@ if __name__ == "__main__":
     trainable_params = tree_reduce(
         lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
     )
-    print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
+    print(f"Training {trainable_params / 1024 ** 2:.3f}M parameters", flush=True)
 
     # Set up the optimizer and training steps. The steps are a bit verbose to
     # support gradient accumulation together with compilation.
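
The parameter count printed here folds over the (possibly nested) tree returned by trainable_parameters(). A sketch with a made-up parameter tree standing in for flux.flow.trainable_parameters():

import mlx.core as mx
from mlx.utils import tree_reduce

# Hypothetical parameter tree: tree_reduce visits every array leaf and sums
# the element counts, just like the snippet above.
params = {"w": mx.zeros((1024, 1024)), "b": mx.zeros((1024,))}
n_params = tree_reduce(lambda acc, x: acc + x.size, params, 0)
print(f"Training {n_params / 1024 ** 2:.3f}M parameters")  # -> Training 1.001M parameters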