mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 02:53:41 +08:00
Minor changes
This commit is contained in:
parent
68518a3194
commit
532e961f58
@ -21,8 +21,9 @@ The dependencies are minimal, namely:
|
|||||||
|
|
||||||
- `huggingface-hub` to download the checkpoints.
|
- `huggingface-hub` to download the checkpoints.
|
||||||
- `regex` for the tokenization
|
- `regex` for the tokenization
|
||||||
- `tqdm`, `PIL`, and `numpy` for the `txt2image.py` script
|
- `tqdm`, `PIL`, and `numpy` for the scripts
|
||||||
- `sentencepiece` for the T5 tokenizer
|
- `sentencepiece` for the T5 tokenizer
|
||||||
|
- `datasets` for using an HF dataset directly
|
||||||
|
|
||||||
You can install all of the above with the `requirements.txt` as follows:
|
You can install all of the above with the `requirements.txt` as follows:
|
||||||
|
|
||||||
|
@ -5,51 +5,69 @@ from PIL import Image
|
|||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
def __init__(self, data):
|
def __getitem__(self, index: int):
|
||||||
self._data = data
|
raise NotImplementedError()
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self._data is None:
|
raise NotImplementedError()
|
||||||
return 0
|
|
||||||
return len(self._data)
|
|
||||||
|
|
||||||
|
|
||||||
class LocalDataset(Dataset):
|
class LocalDataset(Dataset):
|
||||||
|
prompt_key = "prompt"
|
||||||
|
|
||||||
def __init__(self, dataset: str, data_file):
|
def __init__(self, dataset: str, data_file):
|
||||||
self.dataset_base = Path(dataset)
|
self.dataset_base = Path(dataset)
|
||||||
with open(data_file, "r") as fid:
|
with open(data_file, "r") as fid:
|
||||||
self._data = [json.loads(l) for l in fid]
|
self._data = [json.loads(l) for l in fid]
|
||||||
|
|
||||||
super().__init__(self._data)
|
def __len__(self):
|
||||||
|
return len(self._data)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
item = self._data[index]
|
item = self._data[index]
|
||||||
image = Image.open(self.dataset_base / item["image"])
|
image = Image.open(self.dataset_base / item["image"])
|
||||||
return image, item["prompt"]
|
return image, item[self.prompt_key]
|
||||||
|
|
||||||
|
|
||||||
|
class LegacyDataset(LocalDataset):
|
||||||
|
prompt_key = "text"
|
||||||
|
|
||||||
|
def __init__(self, dataset: str):
|
||||||
|
self.dataset_base = Path(dataset)
|
||||||
|
with open(self.dataset_base / "index.json") as f:
|
||||||
|
self._data = json.load(f)["data"]
|
||||||
|
|
||||||
|
|
||||||
class HuggingFaceDataset(Dataset):
|
class HuggingFaceDataset(Dataset):
|
||||||
|
|
||||||
def __init__(self, dataset: str):
|
def __init__(self, dataset: str):
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset as hf_load_dataset
|
||||||
|
|
||||||
df = load_dataset(dataset)["train"]
|
self._df = hf_load_dataset(dataset)["train"]
|
||||||
self._data = df.data
|
|
||||||
super().__init__(df)
|
def __len__(self):
|
||||||
|
return len(self._df)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
item = self._data[index]
|
item = self._df[index]
|
||||||
return item["image"], item["prompt"]
|
return item["image"], item["prompt"]
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(dataset: str):
|
def load_dataset(dataset: str):
|
||||||
dataset_base = Path(dataset)
|
dataset_base = Path(dataset)
|
||||||
data_file = dataset_base / "train.jsonl"
|
data_file = dataset_base / "train.jsonl"
|
||||||
|
legacy_file = dataset_base / "index.json"
|
||||||
|
|
||||||
if data_file.exists():
|
if data_file.exists():
|
||||||
print(f"Load the local dataset {data_file} .", flush=True)
|
print(f"Load the local dataset {data_file} .", flush=True)
|
||||||
dataset = LocalDataset(dataset, data_file)
|
dataset = LocalDataset(dataset, data_file)
|
||||||
|
elif legacy_file.exists():
|
||||||
|
print(f"Load the local dataset {legacy_file} .")
|
||||||
|
print()
|
||||||
|
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
|
||||||
|
print(" See the README for details.")
|
||||||
|
print(flush=True)
|
||||||
|
dataset = LegacyDataset(dataset)
|
||||||
else:
|
else:
|
||||||
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
||||||
dataset = HuggingFaceDataset(dataset)
|
dataset = HuggingFaceDataset(dataset)
|
||||||
|
@ -26,15 +26,15 @@ class Trainer:
|
|||||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||||
crop_size = (
|
crop_size = (
|
||||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||||
max((0.8 + 0.2 * a) * height, resolution[1]),
|
max((0.8 + 0.2 * b) * height, resolution[1]),
|
||||||
)
|
)
|
||||||
pan = (width - crop_size[0], height - crop_size[1])
|
pan = (width - crop_size[0], height - crop_size[1])
|
||||||
img = img.crop(
|
img = img.crop(
|
||||||
(
|
(
|
||||||
pan[0] * b,
|
pan[0] * c,
|
||||||
pan[1] * c,
|
pan[1] * d,
|
||||||
crop_size[0] + pan[0] * b,
|
crop_size[0] + pan[0] * c,
|
||||||
crop_size[1] + pan[1] * c,
|
crop_size[1] + pan[1] * d,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user