mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-09-17 09:08:10 +08:00
FLUX: Optimize dataset loading logic (#1038)
This commit is contained in:
@@ -1,16 +1,10 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
import math
|
||||
import time
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from .datasets import Dataset, load_dataset
|
||||
from .flux import FluxPipeline
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .trainer import Trainer
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
@@ -19,230 +13,3 @@ from .utils import (
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str, t5_padding: bool = True):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.t5_padding = t5_padding
|
||||
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
self.clip = load_clip(name)
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.sampler = FluxSampler(name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
self.ae.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.clip.parameters(),
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def reload_text_encoders(self):
|
||||
self.t5 = load_t5(self.name)
|
||||
self.clip = load_clip(self.name)
|
||||
|
||||
def tokenize(self, text):
|
||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
return t5_tokens, clip_tokens
|
||||
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
||||
# Prepare the text features
|
||||
txt = self.t5(t5_tokens)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _denoising_loop(
|
||||
self,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
B = len(x_t)
|
||||
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the conditioning
|
||||
t5_tokens, clip_tokens = self.tokenize(text)
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
||||
|
||||
# Yield the conditioning for controlled evaluation by the caller
|
||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
||||
|
||||
# Yield the latent sequences from the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
return mx.clip(x + 1, 0, 2) * 0.5
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
reload_text_encoders: bool = True,
|
||||
progress: bool = True,
|
||||
):
|
||||
latents = self.generate_latents(
|
||||
text, n_images, num_steps, guidance, latent_size, seed
|
||||
)
|
||||
mx.eval(next(latents))
|
||||
|
||||
if reload_text_encoders:
|
||||
self.reload_text_encoders()
|
||||
|
||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
||||
mx.eval(x_t)
|
||||
|
||||
images = []
|
||||
for i in tqdm(range(len(x_t)), disable=not progress):
|
||||
images.append(self.decode(x_t[i : i + 1]))
|
||||
mx.eval(images[-1])
|
||||
images = mx.concatenate(images, axis=0)
|
||||
mx.eval(images)
|
||||
|
||||
return images
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return (pred + x_0 - eps).square().mean()
|
||||
|
||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
||||
all_blocks.reverse()
|
||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
||||
for i, block in zip(range(num_blocks), all_blocks):
|
||||
loras = []
|
||||
for name, module in block.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
||||
block.update_modules(tree_unflatten(loras))
|
||||
|
||||
def fuse_lora_layers(self):
|
||||
fused_layers = []
|
||||
for name, module in self.flow.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
fused_layers.append((name, module.fuse()))
|
||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
||||
|
75
flux/flux/datasets.py
Normal file
75
flux/flux/datasets.py
Normal file
@@ -0,0 +1,75 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class Dataset:
|
||||
def __getitem__(self, index: int):
|
||||
raise NotImplementedError()
|
||||
|
||||
def __len__(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
class LocalDataset(Dataset):
|
||||
prompt_key = "prompt"
|
||||
|
||||
def __init__(self, dataset: str, data_file):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(data_file, "r") as fid:
|
||||
self._data = [json.loads(l) for l in fid]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._data)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._data[index]
|
||||
image = Image.open(self.dataset_base / item["image"])
|
||||
return image, item[self.prompt_key]
|
||||
|
||||
|
||||
class LegacyDataset(LocalDataset):
|
||||
prompt_key = "text"
|
||||
|
||||
def __init__(self, dataset: str):
|
||||
self.dataset_base = Path(dataset)
|
||||
with open(self.dataset_base / "index.json") as f:
|
||||
self._data = json.load(f)["data"]
|
||||
|
||||
|
||||
class HuggingFaceDataset(Dataset):
|
||||
|
||||
def __init__(self, dataset: str):
|
||||
from datasets import load_dataset as hf_load_dataset
|
||||
|
||||
self._df = hf_load_dataset(dataset)["train"]
|
||||
|
||||
def __len__(self):
|
||||
return len(self._df)
|
||||
|
||||
def __getitem__(self, index: int):
|
||||
item = self._df[index]
|
||||
return item["image"], item["prompt"]
|
||||
|
||||
|
||||
def load_dataset(dataset: str):
|
||||
dataset_base = Path(dataset)
|
||||
data_file = dataset_base / "train.jsonl"
|
||||
legacy_file = dataset_base / "index.json"
|
||||
|
||||
if data_file.exists():
|
||||
print(f"Load the local dataset {data_file} .", flush=True)
|
||||
dataset = LocalDataset(dataset, data_file)
|
||||
elif legacy_file.exists():
|
||||
print(f"Load the local dataset {legacy_file} .")
|
||||
print()
|
||||
print(" WARNING: 'index.json' is deprecated in favor of 'train.jsonl'.")
|
||||
print(" See the README for details.")
|
||||
print(flush=True)
|
||||
dataset = LegacyDataset(dataset)
|
||||
else:
|
||||
print(f"Load the Hugging Face dataset {dataset} .", flush=True)
|
||||
dataset = HuggingFaceDataset(dataset)
|
||||
|
||||
return dataset
|
246
flux/flux/flux.py
Normal file
246
flux/flux/flux.py
Normal file
@@ -0,0 +1,246 @@
|
||||
# Copyright © 2024 Apple Inc.
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import mlx.core as mx
|
||||
import mlx.nn as nn
|
||||
from mlx.utils import tree_unflatten
|
||||
from tqdm import tqdm
|
||||
|
||||
from .lora import LoRALinear
|
||||
from .sampler import FluxSampler
|
||||
from .utils import (
|
||||
load_ae,
|
||||
load_clip,
|
||||
load_clip_tokenizer,
|
||||
load_flow_model,
|
||||
load_t5,
|
||||
load_t5_tokenizer,
|
||||
)
|
||||
|
||||
|
||||
class FluxPipeline:
|
||||
def __init__(self, name: str, t5_padding: bool = True):
|
||||
self.dtype = mx.bfloat16
|
||||
self.name = name
|
||||
self.t5_padding = t5_padding
|
||||
|
||||
self.ae = load_ae(name)
|
||||
self.flow = load_flow_model(name)
|
||||
self.clip = load_clip(name)
|
||||
self.clip_tokenizer = load_clip_tokenizer(name)
|
||||
self.t5 = load_t5(name)
|
||||
self.t5_tokenizer = load_t5_tokenizer(name)
|
||||
self.sampler = FluxSampler(name)
|
||||
|
||||
def ensure_models_are_loaded(self):
|
||||
mx.eval(
|
||||
self.ae.parameters(),
|
||||
self.flow.parameters(),
|
||||
self.clip.parameters(),
|
||||
self.t5.parameters(),
|
||||
)
|
||||
|
||||
def reload_text_encoders(self):
|
||||
self.t5 = load_t5(self.name)
|
||||
self.clip = load_clip(self.name)
|
||||
|
||||
def tokenize(self, text):
|
||||
t5_tokens = self.t5_tokenizer.encode(text, pad=self.t5_padding)
|
||||
clip_tokens = self.clip_tokenizer.encode(text)
|
||||
return t5_tokens, clip_tokens
|
||||
|
||||
def _prepare_latent_images(self, x):
|
||||
b, h, w, c = x.shape
|
||||
|
||||
# Pack the latent image to 2x2 patches
|
||||
x = x.reshape(b, h // 2, 2, w // 2, 2, c)
|
||||
x = x.transpose(0, 1, 3, 5, 2, 4).reshape(b, h * w // 4, c * 4)
|
||||
|
||||
# Create positions ids used to positionally encode each patch. Due to
|
||||
# the way RoPE works, this results in an interesting positional
|
||||
# encoding where parts of the feature are holding different positional
|
||||
# information. Namely, the first part holds information independent of
|
||||
# the spatial position (hence 0s), the 2nd part holds vertical spatial
|
||||
# information and the last one horizontal.
|
||||
i = mx.zeros((h // 2, w // 2), dtype=mx.int32)
|
||||
j, k = mx.meshgrid(mx.arange(h // 2), mx.arange(w // 2), indexing="ij")
|
||||
x_ids = mx.stack([i, j, k], axis=-1)
|
||||
x_ids = mx.repeat(x_ids.reshape(1, h * w // 4, 3), b, 0)
|
||||
|
||||
return x, x_ids
|
||||
|
||||
def _prepare_conditioning(self, n_images, t5_tokens, clip_tokens):
|
||||
# Prepare the text features
|
||||
txt = self.t5(t5_tokens)
|
||||
if len(txt) == 1 and n_images > 1:
|
||||
txt = mx.broadcast_to(txt, (n_images, *txt.shape[1:]))
|
||||
txt_ids = mx.zeros((n_images, txt.shape[1], 3), dtype=mx.int32)
|
||||
|
||||
# Prepare the clip text features
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
if len(vec) == 1 and n_images > 1:
|
||||
vec = mx.broadcast_to(vec, (n_images, *vec.shape[1:]))
|
||||
|
||||
return txt, txt_ids, vec
|
||||
|
||||
def _denoising_loop(
|
||||
self,
|
||||
x_t,
|
||||
x_ids,
|
||||
txt,
|
||||
txt_ids,
|
||||
vec,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
start: float = 1,
|
||||
stop: float = 0,
|
||||
):
|
||||
B = len(x_t)
|
||||
|
||||
def scalar(x):
|
||||
return mx.full((B,), x, dtype=self.dtype)
|
||||
|
||||
guidance = scalar(guidance)
|
||||
timesteps = self.sampler.timesteps(
|
||||
num_steps,
|
||||
x_t.shape[1],
|
||||
start=start,
|
||||
stop=stop,
|
||||
)
|
||||
for i in range(num_steps):
|
||||
t = timesteps[i]
|
||||
t_prev = timesteps[i + 1]
|
||||
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=scalar(t),
|
||||
guidance=guidance,
|
||||
)
|
||||
x_t = self.sampler.step(pred, x_t, t, t_prev)
|
||||
|
||||
yield x_t
|
||||
|
||||
def generate_latents(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
x_T, x_ids = self._prepare_latent_images(x_T)
|
||||
|
||||
# Get the conditioning
|
||||
t5_tokens, clip_tokens = self.tokenize(text)
|
||||
txt, txt_ids, vec = self._prepare_conditioning(n_images, t5_tokens, clip_tokens)
|
||||
|
||||
# Yield the conditioning for controlled evaluation by the caller
|
||||
yield (x_T, x_ids, txt, txt_ids, vec)
|
||||
|
||||
# Yield the latent sequences from the denoising loop
|
||||
yield from self._denoising_loop(
|
||||
x_T, x_ids, txt, txt_ids, vec, num_steps=num_steps, guidance=guidance
|
||||
)
|
||||
|
||||
def decode(self, x, latent_size: Tuple[int, int] = (64, 64)):
|
||||
h, w = latent_size
|
||||
x = x.reshape(len(x), h // 2, w // 2, -1, 2, 2)
|
||||
x = x.transpose(0, 1, 4, 2, 5, 3).reshape(len(x), h, w, -1)
|
||||
x = self.ae.decode(x)
|
||||
return mx.clip(x + 1, 0, 2) * 0.5
|
||||
|
||||
def generate_images(
|
||||
self,
|
||||
text: str,
|
||||
n_images: int = 1,
|
||||
num_steps: int = 35,
|
||||
guidance: float = 4.0,
|
||||
latent_size: Tuple[int, int] = (64, 64),
|
||||
seed=None,
|
||||
reload_text_encoders: bool = True,
|
||||
progress: bool = True,
|
||||
):
|
||||
latents = self.generate_latents(
|
||||
text, n_images, num_steps, guidance, latent_size, seed
|
||||
)
|
||||
mx.eval(next(latents))
|
||||
|
||||
if reload_text_encoders:
|
||||
self.reload_text_encoders()
|
||||
|
||||
for x_t in tqdm(latents, total=num_steps, disable=not progress, leave=True):
|
||||
mx.eval(x_t)
|
||||
|
||||
images = []
|
||||
for i in tqdm(range(len(x_t)), disable=not progress, desc="generate images"):
|
||||
images.append(self.decode(x_t[i : i + 1]))
|
||||
mx.eval(images[-1])
|
||||
images = mx.concatenate(images, axis=0)
|
||||
mx.eval(images)
|
||||
|
||||
return images
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
||||
# Forward process
|
||||
t = self.sampler.random_timesteps(*x_0.shape[:2], dtype=self.dtype)
|
||||
eps = mx.random.normal(x_0.shape, dtype=self.dtype)
|
||||
x_t = self.sampler.add_noise(x_0, t, noise=eps)
|
||||
x_t = mx.stop_gradient(x_t)
|
||||
|
||||
# Do the denoising
|
||||
pred = self.flow(
|
||||
img=x_t,
|
||||
img_ids=x_ids,
|
||||
txt=txt,
|
||||
txt_ids=txt_ids,
|
||||
y=vec,
|
||||
timesteps=t,
|
||||
guidance=guidance,
|
||||
)
|
||||
|
||||
return (pred + x_0 - eps).square().mean()
|
||||
|
||||
def linear_to_lora_layers(self, rank: int = 8, num_blocks: int = -1):
|
||||
"""Swap the linear layers in the transformer blocks with LoRA layers."""
|
||||
all_blocks = self.flow.double_blocks + self.flow.single_blocks
|
||||
all_blocks.reverse()
|
||||
num_blocks = num_blocks if num_blocks > 0 else len(all_blocks)
|
||||
for i, block in zip(range(num_blocks), all_blocks):
|
||||
loras = []
|
||||
for name, module in block.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
loras.append((name, LoRALinear.from_base(module, r=rank)))
|
||||
block.update_modules(tree_unflatten(loras))
|
||||
|
||||
def fuse_lora_layers(self):
|
||||
fused_layers = []
|
||||
for name, module in self.flow.named_modules():
|
||||
if isinstance(module, LoRALinear):
|
||||
fused_layers.append((name, module.fuse()))
|
||||
self.flow.update_modules(tree_unflatten(fused_layers))
|
98
flux/flux/trainer.py
Normal file
98
flux/flux/trainer.py
Normal file
@@ -0,0 +1,98 @@
|
||||
import mlx.core as mx
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from tqdm import tqdm
|
||||
|
||||
from .datasets import Dataset
|
||||
from .flux import FluxPipeline
|
||||
|
||||
|
||||
class Trainer:
|
||||
|
||||
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
|
||||
self.flux = flux
|
||||
self.dataset = dataset
|
||||
self.args = args
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def _random_crop_resize(self, img):
|
||||
resolution = self.args.resolution
|
||||
width, height = img.size
|
||||
|
||||
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
|
||||
|
||||
# Random crop the input image between 0.8 to 1.0 of its original dimensions
|
||||
crop_size = (
|
||||
max((0.8 + 0.2 * a) * width, resolution[0]),
|
||||
max((0.8 + 0.2 * b) * height, resolution[1]),
|
||||
)
|
||||
pan = (width - crop_size[0], height - crop_size[1])
|
||||
img = img.crop(
|
||||
(
|
||||
pan[0] * c,
|
||||
pan[1] * d,
|
||||
crop_size[0] + pan[0] * c,
|
||||
crop_size[1] + pan[1] * d,
|
||||
)
|
||||
)
|
||||
|
||||
# Fit the largest rectangle with the ratio of resolution in the image
|
||||
# rectangle.
|
||||
width, height = crop_size
|
||||
ratio = resolution[0] / resolution[1]
|
||||
r1 = (height * ratio, height)
|
||||
r2 = (width, width / ratio)
|
||||
r = r1 if r1[0] <= width else r2
|
||||
img = img.crop(
|
||||
(
|
||||
(width - r[0]) / 2,
|
||||
(height - r[1]) / 2,
|
||||
(width + r[0]) / 2,
|
||||
(height + r[1]) / 2,
|
||||
)
|
||||
)
|
||||
|
||||
# Finally resize the image to resolution
|
||||
img = img.resize(resolution, Image.LANCZOS)
|
||||
|
||||
return mx.array(np.array(img))
|
||||
|
||||
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
|
||||
for i in range(num_augmentations):
|
||||
img = self._random_crop_resize(input_img)
|
||||
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(self.flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def _encode_prompt(self, prompt):
|
||||
t5_tok, clip_tok = self.flux.tokenize([prompt])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def encode_dataset(self):
|
||||
"""Encode the images & prompt in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
|
||||
self._encode_image(image, self.args.num_augmentations)
|
||||
self._encode_prompt(prompt)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
xs = mx.concatenate(self.latents)
|
||||
t5 = mx.concatenate(self.t5_features)
|
||||
clip = mx.concatenate(self.clip_features)
|
||||
mx.eval(xs, t5, clip)
|
||||
n_aug = self.args.num_augmentations
|
||||
while True:
|
||||
x_indices = mx.random.permutation(len(self.latents))
|
||||
c_indices = x_indices // n_aug
|
||||
for i in range(0, len(self.latents), batch_size):
|
||||
x_i = x_indices[i : i + batch_size]
|
||||
c_i = c_indices[i : i + batch_size]
|
||||
yield xs[x_i], t5[c_i], clip[c_i]
|
Reference in New Issue
Block a user