mlx-examples/flux/flux/datasets.py

76 lines
2.0 KiB
Python

import json
from pathlib import Path
from PIL import Image
class Dataset:
    """Minimal indexable-dataset interface.

    Both methods raise ``NotImplementedError`` here; the concrete dataset
    classes provide the real implementations.
    """

    def __len__(self):
        """Return the number of items (not implemented on the base class)."""
        raise NotImplementedError

    def __getitem__(self, index: int):
        """Return the item at ``index`` (not implemented on the base class)."""
        raise NotImplementedError
class LocalDataset(Dataset):
    """Dataset backed by a local directory and a JSON Lines index file.

    Every non-blank line of the index file is parsed as one JSON document.
    ``__getitem__`` reads two keys from a record: ``"image"`` (a path joined
    onto the dataset directory) and the key named by ``prompt_key``.
    """

    # Record key that holds the text prompt; subclasses may override it.
    prompt_key = "prompt"

    def __init__(self, dataset: str, data_file):
        """Read all records from ``data_file``.

        Args:
            dataset: Directory that image paths are resolved against.
            data_file: Path of the JSON Lines index (anything ``open``
                accepts).

        Raises:
            OSError: If ``data_file`` cannot be opened.
            json.JSONDecodeError: If a non-blank line is not valid JSON.
        """
        self.dataset_base = Path(dataset)
        # Decode explicitly as UTF-8 instead of the locale default, so that
        # non-ASCII prompts load the same way on every platform.
        with open(data_file, "r", encoding="utf-8") as fid:
            # Blank lines (e.g. an extra newline at the end of the file)
            # carry no record and would make json.loads raise; skip them.
            self._data = [json.loads(line) for line in fid if line.strip()]

    def __len__(self):
        """Return the number of records read from the index file."""
        return len(self._data)

    def __getitem__(self, index: int):
        """Return ``(image, prompt)`` for the record at ``index``.

        The image is whatever ``PIL.Image.open`` returns for the record's
        ``"image"`` path joined onto the dataset directory.
        """
        item = self._data[index]
        image = Image.open(self.dataset_base / item["image"])
        return image, item[self.prompt_key]
class LegacyDataset(LocalDataset):
    """Local dataset described by a legacy ``index.json`` file.

    The index is a single JSON document whose ``"data"`` entry holds the
    records; prompts are stored under the ``"text"`` key.
    """

    # Legacy records keep the prompt under "text" rather than "prompt".
    prompt_key = "text"

    def __init__(self, dataset: str):
        """Read the records from ``<dataset>/index.json``.

        Args:
            dataset: Directory containing ``index.json``; image paths are
                resolved against it.

        Raises:
            OSError: If ``index.json`` cannot be opened.
            json.JSONDecodeError: If the file is not valid JSON.
            KeyError: If the document has no ``"data"`` entry.
        """
        # LocalDataset.__init__ is intentionally not called: it expects a
        # JSON Lines file, whereas the legacy index is one JSON document.
        self.dataset_base = Path(dataset)
        # Decode explicitly as UTF-8 instead of the locale default, so that
        # non-ASCII prompts load the same way on every platform.
        with open(self.dataset_base / "index.json", encoding="utf-8") as f:
            self._data = json.load(f)["data"]
class HuggingFaceDataset(Dataset):
    """Dataset serving image/prompt pairs from a Hugging Face ``"train"`` split."""

    def __init__(self, dataset: str):
        """Load ``dataset`` with ``datasets.load_dataset`` and keep its ``"train"`` entry."""
        # Imported inside the constructor, so the ``datasets`` package is
        # only required when this class is actually instantiated.
        from datasets import load_dataset as hf_load_dataset

        loaded = hf_load_dataset(dataset)
        self._df = loaded["train"]

    def __len__(self):
        """Return the length of the stored ``"train"`` split."""
        return len(self._df)

    def __getitem__(self, index: int):
        """Return the ``"image"`` and ``"prompt"`` fields of the row at ``index``."""
        record = self._df[index]
        image = record["image"]
        prompt = record["prompt"]
        return image, prompt
def load_dataset(dataset: str):
    """Create the dataset object that matches ``dataset``.

    ``dataset`` is first treated as a local directory:

    * if it contains ``train.jsonl``, a ``LocalDataset`` is returned;
    * otherwise, if it contains ``index.json``, a deprecation warning is
      printed and a ``LegacyDataset`` is returned;
    * otherwise ``dataset`` is handed to ``HuggingFaceDataset``.

    A short message naming the chosen source is printed in every case.
    """
    root = Path(dataset)
    data_file = root / "train.jsonl"
    legacy_file = root / "index.json"

    # Preferred layout: a JSON Lines index.
    if data_file.exists():
        print(f"Load the local dataset {data_file} .", flush=True)
        return LocalDataset(dataset, data_file)

    # Deprecated layout: still supported, but warn the user.
    if legacy_file.exists():
        notice = (
            f"Load the local dataset {legacy_file} .",
            "",
            " WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.",
            " See the README for details.",
        )
        for line in notice:
            print(line)
        print(flush=True)
        return LegacyDataset(dataset)

    # No local index found: treat the argument as a Hugging Face dataset.
    print(f"Load the Hugging Face dataset {dataset} .", flush=True)
    return HuggingFaceDataset(dataset)