mlx-examples/flux/flux/trainer.py

99 lines
3.3 KiB
Python

import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm
from .datasets import Dataset
from .flux import FluxPipeline
class Trainer:
def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
self.flux = flux
self.dataset = dataset
self.args = args
self.latents = []
self.t5_features = []
self.clip_features = []
def _random_crop_resize(self, img):
resolution = self.args.resolution
width, height = img.size
a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()
# Random crop the input image between 0.8 to 1.0 of its original dimensions
crop_size = (
max((0.8 + 0.2 * a) * width, resolution[0]),
max((0.8 + 0.2 * b) * height, resolution[1]),
)
pan = (width - crop_size[0], height - crop_size[1])
img = img.crop(
(
pan[0] * c,
pan[1] * d,
crop_size[0] + pan[0] * c,
crop_size[1] + pan[1] * d,
)
)
# Fit the largest rectangle with the ratio of resolution in the image
# rectangle.
width, height = crop_size
ratio = resolution[0] / resolution[1]
r1 = (height * ratio, height)
r2 = (width, width / ratio)
r = r1 if r1[0] <= width else r2
img = img.crop(
(
(width - r[0]) / 2,
(height - r[1]) / 2,
(width + r[0]) / 2,
(height + r[1]) / 2,
)
)
# Finally resize the image to resolution
img = img.resize(resolution, Image.LANCZOS)
return mx.array(np.array(img))
def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
for i in range(num_augmentations):
img = self._random_crop_resize(input_img)
img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
x_0 = self.flux.ae.encode(img[None])
x_0 = x_0.astype(self.flux.dtype)
mx.eval(x_0)
self.latents.append(x_0)
def _encode_prompt(self, prompt):
t5_tok, clip_tok = self.flux.tokenize([prompt])
t5_feat = self.flux.t5(t5_tok)
clip_feat = self.flux.clip(clip_tok).pooled_output
mx.eval(t5_feat, clip_feat)
self.t5_features.append(t5_feat)
self.clip_features.append(clip_feat)
def encode_dataset(self):
"""Encode the images & prompt in the latent space to prepare for training."""
self.flux.ae.eval()
for image, prompt in tqdm(self.dataset, desc="encode dataset"):
self._encode_image(image, self.args.num_augmentations)
self._encode_prompt(prompt)
def iterate(self, batch_size):
xs = mx.concatenate(self.latents)
t5 = mx.concatenate(self.t5_features)
clip = mx.concatenate(self.clip_features)
mx.eval(xs, t5, clip)
n_aug = self.args.num_augmentations
while True:
x_indices = mx.random.permutation(len(self.latents))
c_indices = x_indices // n_aug
for i in range(0, len(self.latents), batch_size):
x_i = x_indices[i : i + batch_size]
c_i = c_indices[i : i + batch_size]
yield xs[x_i], t5[c_i], clip[c_i]