Patch summary (leading text before the first "diff --git" header is skipped by patch tools):

  * flux/README.md: the dependency list now refers to "the scripts" in general
    and lists `datasets` as needed for using an HF dataset directly.
  * flux/flux/datasets.py: `Dataset` becomes an interface whose `__getitem__`
    and `__len__` raise NotImplementedError; `LocalDataset` gains a
    `prompt_key` class attribute and its own `__len__`; the new
    `LegacyDataset` reads `index.json` (records under "data", prompt under
    "text"); `HuggingFaceDataset` keeps the "train" split in `self._df` and
    indexes it directly; `load_dataset` falls back to `index.json` with a
    deprecation warning before trying the Hugging Face hub.
  * flux/flux/trainer.py: the crop uses four distinct values -- `a` and `b` for
    the width/height scale, `c` and `d` for the horizontal/vertical pan --
    instead of reusing `a` for both scales (presumably a..d are independent
    random draws; their definition is outside this hunk).

NOTE(review): line breaks and indentation were reconstructed from a
whitespace-collapsed copy so that every hunk matches its header counts. The
indentation depth of the trainer.py hunk and any whitespace runs inside the two
WARNING string literals could not be recovered -- confirm against the commit.

diff --git a/flux/README.md b/flux/README.md
index 3e10cbf5..fdd8d905 100644
--- a/flux/README.md
+++ b/flux/README.md
@@ -21,8 +21,9 @@ The dependencies are minimal, namely:
 
 - `huggingface-hub` to download the checkpoints.
 - `regex` for the tokenization
-- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
+- `tqdm`, `PIL`, and `numpy` for the scripts
 - `sentencepiece` for the T5 tokenizer
+- `datasets` for using an HF dataset directly
 
 You can install all of the above with the `requirements.txt` as follows:
 
diff --git a/flux/flux/datasets.py b/flux/flux/datasets.py
index e705cf32..d31a09f1 100644
--- a/flux/flux/datasets.py
+++ b/flux/flux/datasets.py
@@ -5,51 +5,69 @@ from PIL import Image
 
 
 class Dataset:
-    def __init__(self, data):
-        self._data = data
+    def __getitem__(self, index: int):
+        raise NotImplementedError()
 
     def __len__(self):
-        if self._data is None:
-            return 0
-        return len(self._data)
+        raise NotImplementedError()
 
 
 class LocalDataset(Dataset):
+    prompt_key = "prompt"
 
     def __init__(self, dataset: str, data_file):
         self.dataset_base = Path(dataset)
         with open(data_file, "r") as fid:
             self._data = [json.loads(l) for l in fid]
 
-        super().__init__(self._data)
+    def __len__(self):
+        return len(self._data)
 
     def __getitem__(self, index: int):
         item = self._data[index]
         image = Image.open(self.dataset_base / item["image"])
-        return image, item["prompt"]
+        return image, item[self.prompt_key]
+
+
+class LegacyDataset(LocalDataset):
+    prompt_key = "text"
+
+    def __init__(self, dataset: str):
+        self.dataset_base = Path(dataset)
+        with open(self.dataset_base / "index.json") as f:
+            self._data = json.load(f)["data"]
 
 
 class HuggingFaceDataset(Dataset):
 
     def __init__(self, dataset: str):
-        from datasets import load_dataset
+        from datasets import load_dataset as hf_load_dataset
 
-        df = load_dataset(dataset)["train"]
-        self._data = df.data
-        super().__init__(df)
+        self._df = hf_load_dataset(dataset)["train"]
+
+    def __len__(self):
+        return len(self._df)
 
     def __getitem__(self, index: int):
-        item = self._data[index]
+        item = self._df[index]
         return item["image"], item["prompt"]
 
 
 def load_dataset(dataset: str):
     dataset_base = Path(dataset)
     data_file = dataset_base / "train.jsonl"
+    legacy_file = dataset_base / "index.json"
 
     if data_file.exists():
         print(f"Load the local dataset {data_file} .", flush=True)
         dataset = LocalDataset(dataset, data_file)
+    elif legacy_file.exists():
+        print(f"Load the local dataset {legacy_file} .")
+        print()
+        print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
+        print(" See the README for details.")
+        print(flush=True)
+        dataset = LegacyDataset(dataset)
     else:
         print(f"Load the Hugging Face dataset {dataset} .", flush=True)
         dataset = HuggingFaceDataset(dataset)
diff --git a/flux/flux/trainer.py b/flux/flux/trainer.py
index ed645941..40a126e8 100644
--- a/flux/flux/trainer.py
+++ b/flux/flux/trainer.py
@@ -26,15 +26,15 @@ class Trainer:
         # Random crop the input image between 0.8 to 1.0 of its original dimensions
         crop_size = (
             max((0.8 + 0.2 * a) * width, resolution[0]),
-            max((0.8 + 0.2 * a) * height, resolution[1]),
+            max((0.8 + 0.2 * b) * height, resolution[1]),
         )
         pan = (width - crop_size[0], height - crop_size[1])
         img = img.crop(
             (
-                pan[0] * b,
-                pan[1] * c,
-                crop_size[0] + pan[0] * b,
-                crop_size[1] + pan[1] * c,
+                pan[0] * c,
+                pan[1] * d,
+                crop_size[0] + pan[0] * c,
+                crop_size[1] + pan[1] * d,
             )
         )
 