Minor changes

This commit is contained in:
Angelos Katharopoulos 2024-10-14 11:39:12 -07:00
parent 68518a3194
commit 532e961f58
3 changed files with 37 additions and 18 deletions

View File

@ -21,8 +21,9 @@ The dependencies are minimal, namely:
- `huggingface-hub` to download the checkpoints.
- `regex` for the tokenization
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
- `tqdm`, `PIL`, and `numpy` for the scripts
- `sentencepiece` for the T5 tokenizer
- `datasets` for using an HF dataset directly
You can install all of the above with the `requirements.txt` as follows:

View File

@ -5,51 +5,69 @@ from PIL import Image
class Dataset:
def __init__(self, data):
self._data = data
def __getitem__(self, index: int):
raise NotImplementedError()
def __len__(self):
if self._data is None:
return 0
return len(self._data)
raise NotImplementedError()
class LocalDataset(Dataset):
prompt_key = "prompt"
def __init__(self, dataset: str, data_file):
self.dataset_base = Path(dataset)
with open(data_file, "r") as fid:
self._data = [json.loads(l) for l in fid]
super().__init__(self._data)
def __len__(self):
return len(self._data)
def __getitem__(self, index: int):
item = self._data[index]
image = Image.open(self.dataset_base / item["image"])
return image, item["prompt"]
return image, item[self.prompt_key]
class LegacyDataset(LocalDataset):
prompt_key = "text"
def __init__(self, dataset: str):
self.dataset_base = Path(dataset)
with open(self.dataset_base / "index.json") as f:
self._data = json.load(f)["data"]
class HuggingFaceDataset(Dataset):
def __init__(self, dataset: str):
from datasets import load_dataset
from datasets import load_dataset as hf_load_dataset
df = load_dataset(dataset)["train"]
self._data = df.data
super().__init__(df)
self._df = hf_load_dataset(dataset)["train"]
def __len__(self):
return len(self._df)
def __getitem__(self, index: int):
item = self._data[index]
item = self._df[index]
return item["image"], item["prompt"]
def load_dataset(dataset: str):
dataset_base = Path(dataset)
data_file = dataset_base / "train.jsonl"
legacy_file = dataset_base / "index.json"
if data_file.exists():
print(f"Load the local dataset {data_file} .", flush=True)
dataset = LocalDataset(dataset, data_file)
elif legacy_file.exists():
print(f"Load the local dataset {legacy_file} .")
print()
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
print(" See the README for details.")
print(flush=True)
dataset = LegacyDataset(dataset)
else:
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
dataset = HuggingFaceDataset(dataset)

View File

@ -26,15 +26,15 @@ class Trainer:
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * a) * height, resolution[1]),
max((0.8 + 0.2 * b) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * b,
pan[1] * c,
crop_size[0] + pan[0] * b,
crop_size[1] + pan[1] * c,
pan[0] * c,
pan[1] * d,
crop_size[0] + pan[0] * c,
crop_size[1] + pan[1] * d,
)
)