mirror of
https://github.com/ml-explore/mlx-examples.git
synced 2025-08-31 11:54:37 +08:00
Update dataset
This commit is contained in:
parent
f2ccad52f4
commit
bb8436a441
@ -1,5 +1,7 @@
|
||||
import argparse
|
||||
import json
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from pathlib import Path
|
||||
|
||||
@ -16,6 +18,106 @@ from flux import FluxPipeline
|
||||
from flux.lora import LoRALinear
|
||||
|
||||
|
||||
@contextmanager
|
||||
def random_state(seed=None):
|
||||
s = mx.random.state[0]
|
||||
try:
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
yield
|
||||
finally:
|
||||
mx.random.state[0] = s
|
||||
|
||||
|
||||
class FinetuningDataset:
|
||||
def __init__(self, flux, args):
|
||||
self.args = args
|
||||
self.flux = flux
|
||||
self.dataset_base = Path(args.dataset)
|
||||
dataset_index = self.dataset_base / "index.json"
|
||||
if not dataset_index.exists():
|
||||
raise ValueError(f"'{args.dataset}' is not a valid finetuning dataset")
|
||||
with open(dataset_index, "r") as f:
|
||||
self.index = json.load(f)
|
||||
|
||||
self.latents = []
|
||||
self.t5_features = []
|
||||
self.clip_features = []
|
||||
|
||||
def encode_images(self):
|
||||
"""Encode the images in the latent space to prepare for training."""
|
||||
self.flux.ae.eval()
|
||||
for sample in tqdm(self.index["data"]):
|
||||
img = Image.open(self.dataset_base / sample["image"])
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
|
||||
x_0 = self.flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(flux.dtype)
|
||||
mx.eval(x_0)
|
||||
self.latents.append(x_0)
|
||||
|
||||
def encode_prompts(self):
|
||||
"""Pre-encode the prompts so that we don't recompute them during
|
||||
training (doesn't allow finetuning the text encoders)."""
|
||||
for sample in tqdm(self.index["data"]):
|
||||
t5_tok, clip_tok = self.flux.tokenize([sample["text"]])
|
||||
t5_feat = self.flux.t5(t5_tok)
|
||||
clip_feat = self.flux.clip(clip_tok).pooled_output
|
||||
mx.eval(t5_feat, clip_feat)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def generate_prior_preservation(self):
|
||||
"""Generate some images and mix them with the training images to avoid
|
||||
overfitting to the dataset."""
|
||||
|
||||
prior_preservation = self.index.get("prior_preservation", None)
|
||||
if not prior_preservation:
|
||||
return
|
||||
|
||||
# Select a random set of prompts from the available ones
|
||||
prior_prompts = mx.random.randint(
|
||||
low=0,
|
||||
high=len(prior_preservation["prompts"]),
|
||||
shape=(prior_preservation["n_images"],),
|
||||
).tolist()
|
||||
|
||||
# For each prompt
|
||||
for prompt_idx in tqdm(prior_prompts):
|
||||
# Create the generator
|
||||
latents = self.flux.generate_latents(
|
||||
prior_preservation["prompts"][prompt_idx],
|
||||
num_steps=prior_preservation.get(
|
||||
"num_steps", 2 if "schnell" in self.flux.name else 35
|
||||
),
|
||||
)
|
||||
|
||||
# Extract the t5 and clip features
|
||||
conditioning = next(latents)
|
||||
mx.eval(conditioning)
|
||||
t5_feat = conditioning[2]
|
||||
clip_feat = conditioning[4]
|
||||
del conditioning
|
||||
|
||||
# Do the denoising
|
||||
for x_t in latents:
|
||||
mx.eval(x_t)
|
||||
|
||||
# Append everything in the data lists
|
||||
self.latents.append(x_t)
|
||||
self.t5_features.append(t5_feat)
|
||||
self.clip_features.append(clip_feat)
|
||||
|
||||
def iterate(self, batch_size):
|
||||
while True:
|
||||
indices = mx.random.randint(0, len(self.latents), (batch_size,)).tolist()
|
||||
x = mx.concatenate([self.latents[i] for i in indices])
|
||||
t5 = mx.concatenate([self.t5_features[i] for i in indices])
|
||||
clip = mx.concatenate([self.clip_features[i] for i in indices])
|
||||
mx.eval(x, t5, clip)
|
||||
yield x, t5, clip
|
||||
|
||||
|
||||
def linear_to_lora_layers(flux, args):
|
||||
lora_layers = []
|
||||
rank = args.lora_rank
|
||||
@ -27,20 +129,6 @@ def linear_to_lora_layers(flux, args):
|
||||
flux.flow.update_modules(tree_unflatten(lora_layers))
|
||||
|
||||
|
||||
def extract_latent_vectors(flux, image_folder):
|
||||
flux.ae.eval()
|
||||
latents = []
|
||||
for image in tqdm(Path(image_folder).iterdir()):
|
||||
img = Image.open(image)
|
||||
img = mx.array(np.array(img))
|
||||
img = (img[:, :, :3].astype(flux.dtype) / 255) * 2 - 1
|
||||
x_0 = flux.ae.encode(img[None])
|
||||
x_0 = x_0.astype(flux.dtype)
|
||||
mx.eval(x_0)
|
||||
latents.append(x_0)
|
||||
return mx.concatenate(latents)
|
||||
|
||||
|
||||
def decode_latents(flux, x):
|
||||
decoded = []
|
||||
for i in tqdm(range(len(x))):
|
||||
@ -50,32 +138,23 @@ def decode_latents(flux, x):
|
||||
|
||||
|
||||
def generate_latents(flux, n_images, prompt, steps, seed=None, leave=True):
|
||||
latents = flux.generate_latents(
|
||||
prompt,
|
||||
n_images=n_images,
|
||||
num_steps=steps,
|
||||
seed=seed,
|
||||
)
|
||||
for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
|
||||
mx.eval(x_t)
|
||||
with random_state(seed):
|
||||
latents = flux.generate_latents(
|
||||
prompt,
|
||||
n_images=n_images,
|
||||
num_steps=steps,
|
||||
)
|
||||
for x_t in tqdm(latents, total=args.progress_steps, leave=leave):
|
||||
mx.eval(x_t)
|
||||
|
||||
return x_t
|
||||
|
||||
|
||||
def iterate_batches(t5_tokens, clip_tokens, x, batch_size):
|
||||
while True:
|
||||
indices = mx.random.randint(0, len(x), (batch_size,))
|
||||
t5_i = t5_tokens[indices]
|
||||
clip_i = clip_tokens[indices]
|
||||
x_i = x[indices]
|
||||
yield t5_i, clip_i, x_i
|
||||
return x_t
|
||||
|
||||
|
||||
def generate_progress_images(iteration, flux, args):
|
||||
out_dir = Path(args.output_dir)
|
||||
out_dir.mkdir(parents=True, exist_ok=True)
|
||||
out_file = out_dir / f"out_{iteration:03d}.png"
|
||||
print(f"Generating {str(out_file)}")
|
||||
print(f"Generating {str(out_file)}", flush=True)
|
||||
# Generate the latent vectors using diffusion
|
||||
n_images = 4
|
||||
latents = generate_latents(
|
||||
@ -118,7 +197,7 @@ if __name__ == "__main__":
|
||||
parser.add_argument(
|
||||
"--iterations",
|
||||
type=int,
|
||||
default=400,
|
||||
default=1000,
|
||||
help="How many iterations to train for",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -129,6 +208,7 @@ if __name__ == "__main__":
|
||||
)
|
||||
parser.add_argument(
|
||||
"--progress-prompt",
|
||||
required=True,
|
||||
help="Use this prompt when generating images for evaluation",
|
||||
)
|
||||
parser.add_argument(
|
||||
@ -156,7 +236,7 @@ if __name__ == "__main__":
|
||||
"--warmup-steps", type=int, default=100, help="Learning rate warmup"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--learning-rate", type=float, default="1e-4", help="Learning rate for training"
|
||||
"--learning-rate", type=float, default="1e-5", help="Learning rate for training"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--grad-accumulate",
|
||||
@ -168,22 +248,24 @@ if __name__ == "__main__":
|
||||
"--output-dir", default="mlx_output", help="Folder to save the checkpoints in"
|
||||
)
|
||||
|
||||
parser.add_argument("prompt")
|
||||
parser.add_argument("image_folder")
|
||||
parser.add_argument("dataset")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
args.progress_prompt = args.progress_prompt or args.prompt
|
||||
# Initialize the seed but different per worker if we are in a distributed
|
||||
# setting.
|
||||
mx.random.seed(0xF0F0F0F0 + mx.distributed.init().rank())
|
||||
|
||||
flux = FluxPipeline("flux-" + args.model)
|
||||
flux.ensure_models_are_loaded()
|
||||
flux.flow.freeze()
|
||||
linear_to_lora_layers(flux, args)
|
||||
with random_state(0x0F0F0F0F):
|
||||
linear_to_lora_layers(flux, args)
|
||||
|
||||
trainable_params = tree_reduce(
|
||||
lambda acc, x: acc + x.size, flux.flow.trainable_parameters(), 0
|
||||
)
|
||||
print(f"Training {trainable_params / 1024**2:.3f}M parameters")
|
||||
print(f"Training {trainable_params / 1024**2:.3f}M parameters", flush=True)
|
||||
|
||||
warmup = optim.linear_schedule(0, args.learning_rate, args.warmup_steps)
|
||||
cosine = optim.cosine_decay(
|
||||
@ -194,9 +276,9 @@ if __name__ == "__main__":
|
||||
state = [flux.flow.state, optimizer.state, mx.random.state]
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def single_step(t5_tokens, clip_tokens, x, guidance):
|
||||
def single_step(x, t5_feat, clip_feat, guidance):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
grads = average_gradients(grads)
|
||||
optimizer.update(flux.flow, grads)
|
||||
@ -204,25 +286,23 @@ if __name__ == "__main__":
|
||||
return loss
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance):
|
||||
def compute_loss_and_grads(t5_feat, clip_feat, x, guidance):
|
||||
return nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def compute_loss_and_accumulate_grads(
|
||||
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||
):
|
||||
def compute_loss_and_accumulate_grads(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
return loss, grads
|
||||
|
||||
@partial(mx.compile, inputs=state, outputs=state)
|
||||
def grad_accumulate_and_step(t5_tokens, clip_tokens, x, guidance, prev_grads):
|
||||
def grad_accumulate_and_step(x, t5_feat, clip_feat, guidance, prev_grads):
|
||||
loss, grads = nn.value_and_grad(flux.flow, flux.training_loss)(
|
||||
t5_tokens, clip_tokens, x, guidance
|
||||
x, t5_feat, clip_feat, guidance
|
||||
)
|
||||
grads = tree_map(lambda a, b: a + b, prev_grads, grads)
|
||||
grads = average_gradients(grads)
|
||||
@ -230,28 +310,30 @@ if __name__ == "__main__":
|
||||
|
||||
return loss
|
||||
|
||||
def step(t5_tokens, clip_tokens, x, guidance, prev_grads, perform_step):
|
||||
def step(x, t5_feat, clip_feat, guidance, prev_grads, perform_step):
|
||||
if prev_grads is None:
|
||||
if perform_step:
|
||||
return single_step(t5_tokens, clip_tokens, x, guidance), None
|
||||
return single_step(x, t5_feat, clip_feat, guidance), None
|
||||
else:
|
||||
return compute_loss_and_grads(t5_tokens, clip_tokens, x, guidance)
|
||||
return compute_loss_and_grads(x, t5_feat, clip_feat, guidance)
|
||||
else:
|
||||
if perform_step:
|
||||
return (
|
||||
grad_accumulate_and_step(
|
||||
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||
x, t5_feat, clip_feat, x, guidance, prev_grads
|
||||
),
|
||||
None,
|
||||
)
|
||||
else:
|
||||
return compute_loss_and_accumulate_grads(
|
||||
t5_tokens, clip_tokens, x, guidance, prev_grads
|
||||
x, t5_feat, clip_feat, guidance, prev_grads
|
||||
)
|
||||
|
||||
print("Encoding training images to latent space")
|
||||
x = extract_latent_vectors(flux, args.image_folder)
|
||||
t5_tokens, clip_tokens = flux.tokenize([args.prompt] * len(x))
|
||||
print("Create the training dataset.", flush=True)
|
||||
dataset = FinetuningDataset(flux, args)
|
||||
dataset.encode_images()
|
||||
dataset.encode_prompts()
|
||||
dataset.generate_prior_preservation()
|
||||
guidance = mx.full((args.batch_size,), 4.0, dtype=flux.dtype)
|
||||
|
||||
# An initial generation to compare
|
||||
@ -260,8 +342,7 @@ if __name__ == "__main__":
|
||||
grads = None
|
||||
losses = []
|
||||
tic = time.time()
|
||||
batches = iterate_batches(t5_tokens, clip_tokens, x, args.batch_size)
|
||||
for i, batch in zip(range(args.iterations), batches):
|
||||
for i, batch in zip(range(args.iterations), dataset.iterate(args.batch_size)):
|
||||
loss, grads = step(*batch, guidance, grads, (i + 1) % args.grad_accumulate == 0)
|
||||
mx.eval(loss, grads, state)
|
||||
losses.append(loss.item())
|
||||
@ -272,7 +353,8 @@ if __name__ == "__main__":
|
||||
print(
|
||||
f"Iter: {i+1} Loss: {sum(losses) / 10:.3f} "
|
||||
f"It/s: {10 / (toc - tic):.3f} "
|
||||
f"Peak mem: {peak_mem:.3f} GB"
|
||||
f"Peak mem: {peak_mem:.3f} GB",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
if (i + 1) % args.progress_every == 0:
|
||||
|
@ -126,8 +126,8 @@ class FluxPipeline:
|
||||
seed=None,
|
||||
):
|
||||
# Set the PRNG state
|
||||
seed = int(time.time()) if seed is None else seed
|
||||
mx.random.seed(seed)
|
||||
if seed is not None:
|
||||
mx.random.seed(seed)
|
||||
|
||||
# Create the latent variables
|
||||
x_T = self.sampler.sample_prior((n_images, *latent_size, 16), dtype=self.dtype)
|
||||
@ -154,15 +154,15 @@ class FluxPipeline:
|
||||
|
||||
def training_loss(
|
||||
self,
|
||||
t5_tokens: mx.array,
|
||||
clip_tokens: mx.array,
|
||||
x_0: mx.array,
|
||||
t5_features: mx.array,
|
||||
clip_features: mx.array,
|
||||
guidance: mx.array,
|
||||
):
|
||||
# Get the text conditioning
|
||||
txt = self.t5(t5_tokens)
|
||||
txt_ids = mx.zeros(t5_tokens.shape + (3,), dtype=mx.int32)
|
||||
vec = self.clip(clip_tokens).pooled_output
|
||||
txt = t5_features
|
||||
txt_ids = mx.zeros(txt.shape[:-1] + (3,), dtype=mx.int32)
|
||||
vec = clip_features
|
||||
|
||||
# Prepare the latent input
|
||||
x_0, x_ids = self._prepare_latent_images(x_0)
|
||||
|
@ -118,8 +118,20 @@ class CLIPTokenizer:
|
||||
|
||||
|
||||
class T5Tokenizer:
|
||||
def __init__(self, model_file):
|
||||
def __init__(self, model_file, max_length=512):
|
||||
self._tokenizer = SentencePieceProcessor(model_file)
|
||||
self.max_length = max_length
|
||||
|
||||
@property
|
||||
def pad(self):
|
||||
try:
|
||||
return self._tokenizer.id_to_piece(self.pad_token)
|
||||
except IndexError:
|
||||
return None
|
||||
|
||||
@property
|
||||
def pad_token(self):
|
||||
return self._tokenizer.pad_id()
|
||||
|
||||
@property
|
||||
def bos(self):
|
||||
@ -143,9 +155,9 @@ class T5Tokenizer:
|
||||
def eos_token(self):
|
||||
return self._tokenizer.eos_id()
|
||||
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True):
|
||||
def tokenize(self, text, prepend_bos=True, append_eos=True, pad=True):
|
||||
if isinstance(text, list):
|
||||
return [self.tokenize(t, prepend_bos, append_eos) for t in text]
|
||||
return [self.tokenize(t, prepend_bos, append_eos, pad) for t in text]
|
||||
|
||||
tokens = self._tokenizer.encode(text)
|
||||
|
||||
@ -153,6 +165,8 @@ class T5Tokenizer:
|
||||
tokens = [self.bos_token] + tokens
|
||||
if append_eos and self.eos_token >= 0:
|
||||
tokens.append(self.eos_token)
|
||||
if len(tokens) < self.max_length and self.pad_token >= 0:
|
||||
tokens += [self.pad_token] * (self.max_length - len(tokens))
|
||||
|
||||
return tokens
|
||||
|
||||
|
@ -204,4 +204,4 @@ def load_clip_tokenizer(name: str):
|
||||
|
||||
def load_t5_tokenizer(name: str):
|
||||
model_file = hf_hub_download(configs[name].repo_id, "tokenizer_2/spiece.model")
|
||||
return T5Tokenizer(model_file)
|
||||
return T5Tokenizer(model_file, 256 if "schnell" in name else 512)
|
||||
|
Loading…
Reference in New Issue
Block a user