FLUX: ref dataset args

This commit is contained in:
madroid 2024-10-13 21:46:21 +08:00
parent 7a20389c06
commit aed4b007fc
2 changed files with 14 additions and 27 deletions

View File

@@ -67,7 +67,7 @@ def setup_arg_parser():
parser.add_argument(
"--model",
default="schnell",
default="dev",
choices=[
"dev",
"schnell",
@@ -247,8 +247,7 @@ if __name__ == "__main__":
x, t5_feat, clip_feat, guidance, prev_grads
)
# print("Create the training dataset.", flush=True)
dataset = load_dataset(flux, args)
dataset = load_dataset(args.dataset)
trainer = Trainer(flux, dataset, args)
trainer.encode_dataset()

View File

@@ -5,19 +5,9 @@ from PIL import Image
class Dataset:
def __init__(self, flux, args, data):
self.args = args
self.flux = flux
def __init__(self, data):
self._data = data
def __getitem__(self, index: int):
item = self._data[index]
image = item["image"]
prompt = item["prompt"]
return image, prompt
def __len__(self):
if self._data is None:
return 0
@@ -26,12 +16,12 @@ class Dataset:
class LocalDataset(Dataset):
def __init__(self, flux, args, data_file):
self.dataset_base = Path(args.dataset)
def __init__(self, dataset: str, data_file):
self.dataset_base = Path(dataset)
with open(data_file, "r") as fid:
self._data = [json.loads(l) for l in fid]
super().__init__(flux, args, self._data)
super().__init__(self._data)
def __getitem__(self, index: int):
item = self._data[index]
@@ -41,29 +31,27 @@ class LocalDataset(Dataset):
class HuggingFaceDataset(Dataset):
def __init__(self, flux, args):
def __init__(self, dataset: str):
from datasets import load_dataset
df = load_dataset(args.dataset)["train"]
df = load_dataset(dataset)["train"]
self._data = df.data
super().__init__(flux, args, df)
super().__init__(df)
def __getitem__(self, index: int):
item = self._data[index]
return item["image"], item["prompt"]
def load_dataset(flux, args):
dataset_base = Path(args.dataset)
def load_dataset(dataset: str):
dataset_base = Path(dataset)
data_file = dataset_base / "train.jsonl"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
# print(f"load local dataset: {data_file}")
dataset = LocalDataset(flux, args, data_file)
dataset = LocalDataset(dataset, data_file)
else:
print(f"Load the Hugging Face dataset {args.dataset} .", flush=True)
# print(f"load Hugging Face dataset: {args.dataset}")
dataset = HuggingFaceDataset(flux, args)
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
dataset = HuggingFaceDataset(dataset)
return dataset