# Mirror of https://github.com/ml-explore/mlx-examples.git (synced 2025-06-27)

import mlx.core as mx
import numpy as np
from PIL import Image, ImageFile
from tqdm import tqdm

from .datasets import Dataset
from .flux import FluxPipeline


class Trainer:
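    """Pre-encode a dataset with the FLUX autoencoder and text encoders, then
    serve shuffled batches of (latent, T5, CLIP) features for training."""
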
    def __init__(self, flux: FluxPipeline, dataset: Dataset, args):
        self.flux = flux
        self.dataset = dataset
        self.args = args

        # Caches of pre-encoded features, filled by encode_dataset()
        self.latents = []
        self.t5_features = []
        self.clip_features = []

    def _random_crop_resize(self, img):
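        """Randomly crop `img` and resize it to `args.resolution`, returning
        the result as an `mx.array` (one augmentation per call)."""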
        resolution = self.args.resolution
        width, height = img.size

        a, b, c, d = mx.random.uniform(shape=(4,), stream=mx.cpu).tolist()

        # Randomly crop the input image to between 0.8 and 1.0 of its original
        # dimensions (but never below the target resolution)
        crop_size = (
            max((0.8 + 0.2 * a) * width, resolution[0]),
            max((0.8 + 0.2 * b) * height, resolution[1]),
        )
        pan = (width - crop_size[0], height - crop_size[1])
        img = img.crop(
            (
                pan[0] * c,
                pan[1] * d,
                crop_size[0] + pan[0] * c,
                crop_size[1] + pan[1] * d,
            )
        )

        # Center-crop to the largest rectangle that fits inside the crop and
        # has the aspect ratio of the target resolution
        width, height = crop_size
        ratio = resolution[0] / resolution[1]
        r1 = (height * ratio, height)
        r2 = (width, width / ratio)
        r = r1 if r1[0] <= width else r2
        img = img.crop(
            (
                (width - r[0]) / 2,
                (height - r[1]) / 2,
                (width + r[0]) / 2,
                (height + r[1]) / 2,
            )
        )

        # Finally resize the image to the target resolution
        img = img.resize(resolution, Image.LANCZOS)

        return mx.array(np.array(img))

    def _encode_image(self, input_img: ImageFile.ImageFile, num_augmentations: int):
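        """Encode `num_augmentations` random crops of `input_img` into latents."""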
        for i in range(num_augmentations):
            img = self._random_crop_resize(input_img)
            # Drop any alpha channel and rescale pixels from [0, 255] to [-1, 1]
            img = (img[:, :, :3].astype(self.flux.dtype) / 255) * 2 - 1
            x_0 = self.flux.ae.encode(img[None])
            x_0 = x_0.astype(self.flux.dtype)
            mx.eval(x_0)
            self.latents.append(x_0)

    def _encode_prompt(self, prompt):
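        """Encode `prompt` with both text encoders and cache the features."""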
        t5_tok, clip_tok = self.flux.tokenize([prompt])
        t5_feat = self.flux.t5(t5_tok)
        clip_feat = self.flux.clip(clip_tok).pooled_output
        mx.eval(t5_feat, clip_feat)
        self.t5_features.append(t5_feat)
        self.clip_features.append(clip_feat)

    def encode_dataset(self):
        """Encode the images & prompts in the latent space to prepare for training."""
        self.flux.ae.eval()
        for image, prompt in tqdm(self.dataset, desc="encode dataset"):
            self._encode_image(image, self.args.num_augmentations)
            self._encode_prompt(prompt)

    def iterate(self, batch_size):
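        """Yield shuffled (latents, T5 features, CLIP features) batches
        indefinitely, pairing every augmented latent with its prompt."""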
        xs = mx.concatenate(self.latents)
        t5 = mx.concatenate(self.t5_features)
        clip = mx.concatenate(self.clip_features)
        mx.eval(xs, t5, clip)
        n_aug = self.args.num_augmentations
        while True:
            x_indices = mx.random.permutation(len(self.latents))
            # Each image contributes n_aug latents but only one prompt
            # encoding, so integer division maps a latent index back to the
            # index of its conditioning
            c_indices = x_indices // n_aug
            for i in range(0, len(self.latents), batch_size):
                x_i = x_indices[i : i + batch_size]
                c_i = c_indices[i : i + batch_size]
                yield xs[x_i], t5[c_i], clip[c_i]
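

# A minimal usage sketch (assuming a constructed `flux` pipeline, a `dataset`,
# and an `args` namespace providing `resolution` and `num_augmentations`; the
# driver script is not part of this file):
#
#     trainer = Trainer(flux, dataset, args)
#     trainer.encode_dataset()
#     batches = trainer.iterate(batch_size=1)
#     x, t5, clip = next(batches)  # one batch of latents plus conditioning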