mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
FLUX: ref dataset args
This commit is contained in:
parent
7a20389c06
commit
aed4b007fc
@ -67,7 +67,7 @@ def setup_arg_parser():
|
|||||||
|
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--model",
|
"--model",
|
||||||
default="schnell",
|
default="dev",
|
||||||
choices=[
|
choices=[
|
||||||
"dev",
|
"dev",
|
||||||
"schnell",
|
"schnell",
|
||||||
@ -247,8 +247,7 @@ if __name__ == "__main__":
|
|||||||
x, t5_feat, clip_feat, guidance, prev_grads
|
x, t5_feat, clip_feat, guidance, prev_grads
|
||||||
)
|
)
|
||||||
|
|
||||||
# print("Create the training dataset.", flush=True)
|
dataset = load_dataset(args.dataset)
|
||||||
dataset = load_dataset(flux, args)
|
|
||||||
trainer = Trainer(flux, dataset, args)
|
trainer = Trainer(flux, dataset, args)
|
||||||
trainer.encode_dataset()
|
trainer.encode_dataset()
|
||||||
|
|
||||||
|
@ -5,19 +5,9 @@ from PIL import Image
|
|||||||
|
|
||||||
|
|
||||||
class Dataset:
|
class Dataset:
|
||||||
def __init__(self, flux, args, data):
|
def __init__(self, data):
|
||||||
self.args = args
|
|
||||||
self.flux = flux
|
|
||||||
|
|
||||||
self._data = data
|
self._data = data
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
|
||||||
item = self._data[index]
|
|
||||||
image = item["image"]
|
|
||||||
prompt = item["prompt"]
|
|
||||||
|
|
||||||
return image, prompt
|
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
if self._data is None:
|
if self._data is None:
|
||||||
return 0
|
return 0
|
||||||
@ -26,12 +16,12 @@ class Dataset:
|
|||||||
|
|
||||||
class LocalDataset(Dataset):
|
class LocalDataset(Dataset):
|
||||||
|
|
||||||
def __init__(self, flux, args, data_file):
|
def __init__(self, dataset: str, data_file):
|
||||||
self.dataset_base = Path(args.dataset)
|
self.dataset_base = Path(dataset)
|
||||||
with open(data_file, "r") as fid:
|
with open(data_file, "r") as fid:
|
||||||
self._data = [json.loads(l) for l in fid]
|
self._data = [json.loads(l) for l in fid]
|
||||||
|
|
||||||
super().__init__(flux, args, self._data)
|
super().__init__(self._data)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
item = self._data[index]
|
item = self._data[index]
|
||||||
@ -41,29 +31,27 @@ class LocalDataset(Dataset):
|
|||||||
|
|
||||||
class HuggingFaceDataset(Dataset):
|
class HuggingFaceDataset(Dataset):
|
||||||
|
|
||||||
def __init__(self, flux, args):
|
def __init__(self, dataset: str):
|
||||||
from datasets import load_dataset
|
from datasets import load_dataset
|
||||||
|
|
||||||
df = load_dataset(args.dataset)["train"]
|
df = load_dataset(dataset)["train"]
|
||||||
self._data = df.data
|
self._data = df.data
|
||||||
super().__init__(flux, args, df)
|
super().__init__(df)
|
||||||
|
|
||||||
def __getitem__(self, index: int):
|
def __getitem__(self, index: int):
|
||||||
item = self._data[index]
|
item = self._data[index]
|
||||||
return item["image"], item["prompt"]
|
return item["image"], item["prompt"]
|
||||||
|
|
||||||
|
|
||||||
def load_dataset(flux, args):
|
def load_dataset(dataset: str):
|
||||||
dataset_base = Path(args.dataset)
|
dataset_base = Path(dataset)
|
||||||
data_file = dataset_base / "train.jsonl"
|
data_file = dataset_base / "train.jsonl"
|
||||||
|
|
||||||
if data_file.exists():
|
if data_file.exists():
|
||||||
print(f"Load the local dataset {data_file} .", flush=True)
|
print(f"Load the local dataset {data_file} .", flush=True)
|
||||||
# print(f"load local dataset: {data_file}")
|
dataset = LocalDataset(dataset, data_file)
|
||||||
dataset = LocalDataset(flux, args, data_file)
|
|
||||||
else:
|
else:
|
||||||
print(f"Load the Hugging Face dataset {args.dataset} .", flush=True)
|
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
||||||
# print(f"load Hugging Face dataset: {args.dataset}")
|
dataset = HuggingFaceDataset(dataset)
|
||||||
dataset = HuggingFaceDataset(flux, args)
|
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
Loading…
Reference in New Issue
Block a user