From aed4b007fc0a77bb1abe75d152bc21784f69bf9b Mon Sep 17 00:00:00 2001 From: madroid Date: Sun, 13 Oct 2024 21:46:21 +0800 Subject: [PATCH] FLUX: ref dataset args --- flux/dreambooth.py | 5 ++--- flux/flux/datasets.py | 36 ++++++++++++------------------------ 2 files changed, 14 insertions(+), 27 deletions(-) diff --git a/flux/dreambooth.py b/flux/dreambooth.py index 444f6a1e..48dcad47 100644 --- a/flux/dreambooth.py +++ b/flux/dreambooth.py @@ -67,7 +67,7 @@ def setup_arg_parser(): parser.add_argument( "--model", - default="schnell", + default="dev", choices=[ "dev", "schnell", @@ -247,8 +247,7 @@ if __name__ == "__main__": x, t5_feat, clip_feat, guidance, prev_grads ) - # print("Create the training dataset.", flush=True) - dataset = load_dataset(flux, args) + dataset = load_dataset(args.dataset) trainer = Trainer(flux, dataset, args) trainer.encode_dataset() diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py index 5a845d82..e705cf32 100644 --- a/flux/flux/datasets.py +++ b/flux/flux/datasets.py @@ -5,19 +5,9 @@ from PIL import Image class Dataset: - def __init__(self, flux, args, data): - self.args = args - self.flux = flux - + def __init__(self, data): self._data = data - def __getitem__(self, index: int): - item = self._data[index] - image = item["image"] - prompt = item["prompt"] - - return image, prompt - def __len__(self): if self._data is None: return 0 @@ -26,12 +16,12 @@ class Dataset: class LocalDataset(Dataset): - def __init__(self, flux, args, data_file): - self.dataset_base = Path(args.dataset) + def __init__(self, dataset: str, data_file): + self.dataset_base = Path(dataset) with open(data_file, "r") as fid: self._data = [json.loads(l) for l in fid] - super().__init__(flux, args, self._data) + super().__init__(self._data) def __getitem__(self, index: int): item = self._data[index] @@ -41,29 +31,27 @@ class LocalDataset(Dataset): class HuggingFaceDataset(Dataset): - def __init__(self, flux, args): + def __init__(self, dataset: str): from datasets import load_dataset - df = load_dataset(args.dataset)["train"] + df = load_dataset(dataset)["train"] self._data = df.data - super().__init__(flux, args, df) + super().__init__(df) def __getitem__(self, index: int): item = self._data[index] return item["image"], item["prompt"] -def load_dataset(flux, args): - dataset_base = Path(args.dataset) +def load_dataset(dataset: str): + dataset_base = Path(dataset) data_file = dataset_base / "train.jsonl" if data_file.exists(): print(f"Load the local dataset {data_file} .", flush=True) - # print(f"load local dataset: {data_file}") - dataset = LocalDataset(flux, args, data_file) + dataset = LocalDataset(dataset, data_file) else: - print(f"Load the Hugging Face dataset {args.dataset} .", flush=True) - # print(f"load Hugging Face dataset: {args.dataset}") - dataset = HuggingFaceDataset(flux, args) + print(f"Load the Hugging Face dataset {dataset} .", flush=True) + dataset = HuggingFaceDataset(dataset) return dataset