mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-30 10:56:38 +08:00
FLUX: ref dataset args
This commit is contained in:
parent
7a20389c06
commit
aed4b007fc
@ -67,7 +67,7 @@ def setup_arg_parser():
|
||||
|
||||
parser.add_argument(
|
||||
"--model",
|
||||
default="schnell",
|
||||
default="dev",
|
||||
choices=[
|
||||
"dev",
|
||||
"schnell",
|
||||
@ -247,8 +247,7 @@ if __name__ == "__main__":
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
# print("Create the training dataset.", flush=True)
|
||||
dataset = load_dataset(flux, args)
|
||||
dataset = load_dataset(args.dataset)
|
||||
trainer = Trainer(flux, dataset, args)
|
||||
trainer.encode_dataset()
|
||||
|
||||
|
@ -5,19 +5,9 @@ from PIL import Image
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __init__(self, flux, args, data):
|
||||
self.args = args
|
||||
self.flux = flux
|
||||
|
||||
def __init__(self, data):
|
||||
self._data = data
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
image = item["image"]
|
||||
prompt = item["prompt"]
|
||||
|
||||
return image, prompt
|
||||
|
||||
def __len__(self):
|
||||
if self._data is None:
|
||||
return 0
|
||||
@ -26,12 +16,12 @@ class Dataset:
|
||||
|
||||
class LocalDataset(Dataset):
|
||||
|
||||
def __init__(self, flux, args, data_file):
|
||||
self.dataset_base = Path(args.dataset)
|
||||
def __init__(self, dataset: str, data_file):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(data_file, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
|
||||
super().__init__(flux, args, self._data)
|
||||
super().__init__(self._data)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
@ -41,29 +31,27 @@ class LocalDataset(Dataset):
|
||||
|
||||
class HuggingFaceDataset(Dataset):
|
||||
|
||||
def __init__(self, flux, args):
|
||||
def __init__(self, dataset: str):
|
||||
from datasets import load_dataset
|
||||
|
||||
df = load_dataset(args.dataset)["train"]
|
||||
df = load_dataset(dataset)["train"]
|
||||
self._data = df.data
|
||||
super().__init__(flux, args, df)
|
||||
super().__init__(df)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
return item["image"], item["prompt"]
|
||||
|
||||
|
||||
def load_dataset(flux, args):
|
||||
dataset_base = Path(args.dataset)
|
||||
def load_dataset(dataset: str):
|
||||
dataset_base = Path(dataset)
|
||||
data_file = dataset_base / "train.jsonl"
|
||||
|
||||
if data_file.exists():
|
||||
print(f"Load the local dataset {data_file} .", flush=True)
|
||||
# print(f"load local dataset: {data_file}")
|
||||
dataset = LocalDataset(flux, args, data_file)
|
||||
dataset = LocalDataset(dataset, data_file)
|
||||
else:
|
||||
print(f"Load the Hugging Face dataset {args.dataset} .", flush=True)
|
||||
# print(f"load Hugging Face dataset: {args.dataset}")
|
||||
dataset = HuggingFaceDataset(flux, args)
|
||||
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
||||
dataset = HuggingFaceDataset(dataset)
|
||||
|
||||
return dataset
|
||||
|
Loading…
Reference in New Issue
Block a user